From 7998b9d804bb8ce28d777e44392bc12b33e30d7a Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 23 Jan 2025 15:24:59 -0500 Subject: [PATCH 01/32] tmp --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 17 +- orig.mlir | 381 ++++++++++++++++++ python/tutorials/09-persistent-matmul.py | 166 ++++---- test.mlir | 177 ++++++++ test1.mlir | 149 +++++++ 5 files changed, 796 insertions(+), 94 deletions(-) create mode 100644 orig.mlir create mode 100644 test.mlir create mode 100644 test1.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index a74d274d386e..ecba8c8582d7 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -1,6 +1,7 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Dominance.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "llvm/Support/Debug.h" #include @@ -778,14 +779,28 @@ static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) { // Pass Implementation //===----------------------------------------------------------------------===// +// Fuse simple loop nests with a single outer and inner loop, and where the +// inner loop has a `tt.dot` operation. +static bool shouldFuse(const LoopNest &nest) { + if (nest.nodes.size() != 2 || nest.root->children.size() != 1) + return false; + + scf::ForOp innerLoop = nest.root->children.front()->loop; + return llvm::any_of(innerLoop.getOps(), + [](Operation &op) { return isa<triton::DotOp>(op); }); +} + void FuseNestedLoopsPass::runOnOperation() { auto &domInfo = getAnalysis<DominanceInfo>(); for (auto func : getOperation().getOps<triton::FuncOp>()) { SmallVector<LoopNest> nests; findLoopNests(func, nests); - for (LoopNest &nest : nests) + for (LoopNest &nest : nests) { + if (!shouldFuse(nest)) + continue; flattenLoopNest(nest.root, domInfo); + } } } diff --git a/orig.mlir b/orig.mlir new file mode 100644 index 000000000000..c6f2580f8198 --- /dev/null +++ b/orig.mlir @@ -0,0 +1,381 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %false = arith.constant false loc(#loc1) + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c132_i32 = arith.constant 132 : i32 loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c255_i32 = arith.constant 255 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc80) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc81) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc82) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc83) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc84) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc85) + %7 = arith.muli %2, %4 : i32 loc(#loc8) + %8 = arith.divsi %7, %c132_i32 : i32 loc(#loc9) + %9 = arith.remsi %7, %c132_i32 : i32 loc(#loc10) + %10 = arith.cmpi slt, %0, %9 : i32 loc(#loc11) + %11 = scf.if %10 -> (i32) { + %122 = arith.addi %8, %c1_i32 : i32 loc(#loc13) + scf.yield %122 : i32 loc(#loc13) + } else { + scf.yield %8 : i32 loc(#loc1) + } loc(#loc12) + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc14) + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc14) + %14 = arith.muli %4, %c8_i32 : i32 loc(#loc15) + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc16) + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc16) + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc17) + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc17) + %19 = arith.muli %6, %11 : i32 loc(#loc18) + %20 = arith.subi %6, %c1_i32 : i32 loc(#loc19) + %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc20) + %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc21) + %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> 
loc(#loc22) + %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc23) + %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc24) + %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc25) + %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc26) + %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc27) + %29 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc28) + %30 = arith.divsi %0, %14 : i32 loc(#loc29) + %31 = arith.muli %30, %c8_i32 : i32 loc(#loc30) + %32 = arith.subi %2, %31 : i32 loc(#loc31) + %33 = arith.minsi %32, %c8_i32 : i32 loc(#loc32) + %34 = arith.remsi %0, %33 : i32 loc(#loc33) + %35 = arith.addi %31, %34 : i32 loc(#loc34) + %36 = arith.remsi %0, %14 : i32 loc(#loc35) + %37 = arith.divsi %36, %33 : i32 loc(#loc36) + %38 = arith.muli %35, %c128_i32 : i32 loc(#loc37) + %39 = arith.muli %37, %c256_i32 : i32 loc(#loc38) + %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> loc(#loc46) + %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> loc(#loc48) + %61 = tt.addptr 
%24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) + %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc28) + %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %69 = ttg.async_commit_group %68 loc(#loc26) + %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> loc(#loc50) + %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc28) + %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %77 = ttg.async_commit_group %76 loc(#loc27) + %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc28) + %79 = arith.cmpi ne, %20, %c0_i32 : i32 loc(#loc86) + %80 = arith.extui %79 : i1 to i32 loc(#loc51) + %81 = arith.cmpi eq, %80, %c0_i32 : i32 loc(#loc53) + %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %122 = arith.addi %0, %c132_i32 : i32 loc(#loc55) + %123 = arith.divsi %122, %14 : i32 loc(#loc29) + %124 = arith.muli %123, %c8_i32 : i32 loc(#loc30) + %125 = arith.subi %2, %124 : i32 loc(#loc31) + %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc32) + %127 = arith.remsi %122, %126 : i32 loc(#loc33) + %128 = arith.addi %124, %127 : i32 loc(#loc34) + %129 = arith.remsi %122, %14 : i32 loc(#loc35) + %130 = arith.divsi %129, %126 : i32 loc(#loc36) + %131 = arith.muli %128, %c128_i32 : i32 loc(#loc37) + %132 = arith.muli %130, %c256_i32 : i32 loc(#loc38) + %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : 
tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + } else { + scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + } loc(#loc54) + %83 = arith.muli %80, %c64_i32 : i32 loc(#loc56) + %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) + %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc46) + %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) + %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc48) + %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %102 = arith.subi %arg5, %83 : i32 loc(#loc60) + %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) + %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc28) + %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %110 = ttg.async_commit_group %109 loc(#loc26) + %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> loc(#loc50) + %113 = tt.broadcast %112 : 
tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc28) + %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %118 = ttg.async_commit_group %117 loc(#loc27) + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %122 = arith.subi %19, %c2_i32 : i32 loc(#loc28) + %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc28) + %124 = arith.cmpi eq, %arg10, %20 : i32 loc(#loc52) + %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc61) + %126 = arith.select %124, %c0_i32, %125 : i32 loc(#loc51) + %127 = arith.cmpi eq, %126, %c0_i32 : i32 loc(#loc53) + %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %178 = arith.addi %arg11, %c132_i32 : i32 loc(#loc55) + %179 = arith.divsi %178, %14 : i32 loc(#loc29) + %180 = arith.muli %179, %c8_i32 : i32 loc(#loc30) + %181 = arith.subi %2, %180 : i32 loc(#loc31) + %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc32) + %183 = arith.remsi %178, %182 : i32 loc(#loc33) + %184 = arith.addi %180, %183 : i32 loc(#loc34) + %185 = arith.remsi %178, %14 : i32 loc(#loc35) + %186 = arith.divsi %185, %182 : i32 loc(#loc36) + %187 = arith.muli %184, %c128_i32 : i32 loc(#loc37) + %188 = arith.muli %186, %c256_i32 : i32 loc(#loc38) + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = 
#blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + } else { + scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + } loc(#loc54) + %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc28) + %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc28) + %131 = arith.select %130, %129, %c0_i32 : i32 loc(#loc28) + %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc26) + %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc62) + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) + %137 = arith.addi %arg18, %c1_i32 : i32 loc(#loc28) + %138 = arith.cmpi slt, %137, %c3_i32 : i32 loc(#loc28) + %139 = arith.select %138, %137, %c0_i32 : i32 loc(#loc28) + %140 = arith.muli %126, %c64_i32 : i32 loc(#loc56) + %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) + %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> loc(#loc46) + %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) + %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> loc(#loc48) + %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> 
loc(#loc23) + %159 = arith.subi %arg5, %140 : i32 loc(#loc60) + %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> loc(#loc49) + %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> loc(#loc28) + %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %167 = ttg.async_commit_group %166 loc(#loc26) + %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> loc(#loc50) + %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> loc(#loc28) + %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %175 = ttg.async_commit_group %174 loc(#loc27) + %176 = arith.cmpi eq, %arg22, %20 : i32 loc(#loc63) + %177 = arith.cmpi ne, %arg22, %20 : i32 loc(#loc87) + scf.if %176 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) + %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc65) + %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) + %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) + %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc67) + %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) + %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) + %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc69) + %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc70) + %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc70) + %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc71) + %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc71) + %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc72) + %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc73) + %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc73) + %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc73) + %194 = tt.splat %arg3 : i32 -> 
tensor<128x1xi32, #blocked2> loc(#loc74) + %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc74) + %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc75) + %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc75) + %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) + %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) + %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc76) + %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc77) + %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc78) + tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc78) + } loc(#loc64) + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc28) + } loc(#loc28) + %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc28) + %121 = ttg.async_wait {num = 0 : i32} loc(#loc28) + ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc28) + ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc28) + tt.return loc(#loc79) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) +#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) +#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":172:32) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:31) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:19) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:7) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:24) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:35) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":181:38) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:27) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:32) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:38) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:45) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:26) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:75) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) +#loc24 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:49) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:49) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:20) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:22) +#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:34) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:43) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:56) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:45) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:35) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:31) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:52) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":200:30) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:32) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:41) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:53) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:34) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:57) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:56) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:60) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:44) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:28) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:17) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:11) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:23) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:22) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:37) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:64) +#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:33) +#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:64) +#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:49) +#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:35) +#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:17) +#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:11) +#loc65 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:30) +#loc66 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:45) +#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) +#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) +#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:49) +#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:41) +#loc71 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:29) +#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:80) +#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:60) +#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) +#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:66) +#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:47) +#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":224:35) +#loc78 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:29) +#loc79 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:4) +#loc80 = loc(callsite(#loc3 at #loc4)) +#loc81 = loc(callsite(#loc5 at #loc4)) +#loc82 = loc(callsite(#loc3 at #loc6)) +#loc83 = loc(callsite(#loc5 at #loc6)) +#loc84 = loc(callsite(#loc3 at #loc7)) +#loc85 = loc(callsite(#loc5 at #loc7)) +#loc86 = loc(fused[#loc51, #loc52]) +#loc87 = loc(fused[#loc64, #loc63]) + diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index eec0c6248c0f..0d776ba0f3fd 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -168,62 +168,45 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - pid_m = 0 - pid_n = 0 - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - start_m = pid_m * BLOCK_SIZE_M - start_n = pid_n * BLOCK_SIZE_N - offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) - offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) - offs_am = tl.where(offs_am < M, offs_am, 0) - offs_bn = tl.where(offs_bn < N, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) - offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, b, accumulator) + offs_k_for_mask = 
tl.arange(0, BLOCK_SIZE_K) - if ki == k_tiles - 1: - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if (c_ptr.dtype.element_ty == tl.float8e4nv): - c = accumulator.to(tl.float8e4nv) - else: - c = accumulator.to(tl.float16) - tl.store(c_ptrs, c, mask=c_mask) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for tile_id in range(start_pid, num_tiles, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) def matmul_persistent(a, b): @@ -261,6 +244,22 @@ def matmul_persistent(a, b): num_stages=configs[dtype]["num_stages"], # num_warps=configs[dtype]["num_warps"], # ) + kernel = matmul_kernel_persistent.warmup( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + grid=grid + ) + print(kernel.asm["ttir"]) return c @@ -279,47 +278,28 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for tile_id in range(start_pid, num_tiles, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * 
GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N - offs_am = pid_m * BLOCK_SIZE_M - offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K - offs_k = ki * BLOCK_SIZE_K - - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) - accumulator = tl.dot(a, b.T, accumulator) + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + accumulator = tl.dot(a, b.T, accumulator) - if ki == k_tiles - 1: - c = accumulator.to(dtype) - - tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + c = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) def matmul_tma_persistent(a, b): @@ -630,8 +610,8 @@ def show_profile(precision, profile_name): validate(32, 32, 32, dtype) validate(8192, 8192, 512, dtype) - proton.start("matmul", hook="triton") - for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - bench(K, dtype) - proton.finalize() - show_profile(args.prec, "matmul") + #proton.start("matmul", hook="triton") + #for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + # bench(K, dtype) + #proton.finalize() + #show_profile(args.prec, "matmul") diff --git a/test.mlir b/test.mlir new file mode 100644 index 000000000000..345d793c2124 --- /dev/null +++ b/test.mlir @@ -0,0 +1,177 @@ +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) +module { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { + %cst = arith.constant 
dense<0.000000e+00> : tensor<128x256xf32> loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %c255_i32 = arith.constant 255 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c132_i32 = arith.constant 132 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<256xi32> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<128xi32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc63) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc64) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc65) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc66) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc67) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc68) + %7 = arith.muli %2, %4 : i32 loc(#loc8) + %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc10) + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc11) + %11 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc12) + %12 = tt.splat %arg3 : i32 -> tensor<128xi32> loc(#loc13) + %13 = tt.splat %arg4 : i32 -> tensor<256xi32> loc(#loc14) + %14 = tt.splat %arg6 : i32 -> tensor<128x1xi32> loc(#loc15) + %15 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> loc(#loc16) + %16 = tt.splat %arg7 : i32 -> tensor<1x256xi32> loc(#loc17) + %17 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> loc(#loc18) + %18 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc19) + %19 = tt.expand_dims %9 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc20) + %20 = tt.splat %arg8 : i32 -> tensor<128x1xi32> loc(#loc21) + %21 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc22) + %22 = tt.splat %arg3 : i32 -> tensor<128x1xi32> loc(#loc23) + %23 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc24) + scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { + %24 = arith.divsi %arg9, %8 : i32 loc(#loc26) + %25 = arith.muli %24, %c8_i32 : i32 loc(#loc27) + %26 = arith.subi %2, %25 : i32 loc(#loc28) + %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc29) + %28 = arith.remsi %arg9, %27 : i32 loc(#loc30) + %29 = arith.addi %25, %28 : i32 loc(#loc31) + %30 = arith.remsi %arg9, %8 : i32 loc(#loc32) + %31 = arith.divsi %30, %27 : i32 loc(#loc33) + %32 = arith.muli %29, %c128_i32 : i32 loc(#loc34) + %33 = arith.muli %31, %c256_i32 : i32 loc(#loc35) + %34 = tt.splat %32 : i32 -> tensor<128xi32> loc(#loc36) + %35 = arith.addi %34, %10 : tensor<128xi32> loc(#loc36) + %36 = tt.splat %33 : i32 -> tensor<256xi32> loc(#loc37) + %37 = arith.addi %36, %11 : tensor<256xi32> loc(#loc37) + %38 = arith.cmpi slt, %35, %12 : tensor<128xi32> loc(#loc13) + %39 = arith.select %38, %35, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> loc(#loc38) + %40 = arith.cmpi slt, %37, %13 : tensor<256xi32> loc(#loc14) + %41 = arith.select %40, %37, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : 
tensor<256xi1>, tensor<256xi32> loc(#loc39) + %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc40) + %43 = arith.muli %42, %14 : tensor<128x1xi32> loc(#loc15) + %44 = tt.broadcast %43 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc41) + %45 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc42) + %46 = arith.muli %45, %16 : tensor<1x256xi32> loc(#loc17) + %47 = tt.broadcast %46 : tensor<1x256xi32> -> tensor<64x256xi32> loc(#loc43) + %48 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst) -> (tensor<128x256xf32>) : i32 { + %62 = arith.muli %arg10, %c64_i32 : i32 loc(#loc45) + %63 = tt.splat %62 : i32 -> tensor<64xi32> loc(#loc46) + %64 = arith.addi %63, %9 : tensor<64xi32> loc(#loc46) + %65 = tt.expand_dims %64 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc47) + %66 = tt.broadcast %65 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc41) + %67 = arith.addi %44, %66 : tensor<128x64xi32> loc(#loc41) + %68 = tt.addptr %15, %67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc16) + %69 = tt.expand_dims %64 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc48) + %70 = tt.broadcast %69 : tensor<64x1xi32> -> tensor<64x256xi32> loc(#loc43) + %71 = arith.addi %70, %47 : tensor<64x256xi32> loc(#loc43) + %72 = tt.addptr %17, %71 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> loc(#loc18) + %73 = arith.subi %arg5, %62 : i32 loc(#loc49) + %74 = tt.splat %73 : i32 -> tensor<1x64xi32> loc(#loc50) + %75 = arith.cmpi slt, %18, %74 : tensor<1x64xi32> loc(#loc50) + %76 = tt.broadcast %75 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc51) + %77 = tt.load %68, %76, %cst_1 : tensor<128x64x!tt.ptr> loc(#loc51) + %78 = tt.splat %73 : i32 -> tensor<64x1xi32> loc(#loc52) + %79 = arith.cmpi slt, %19, %78 : tensor<64x1xi32> loc(#loc52) + %80 = tt.broadcast %79 : tensor<64x1xi1> -> tensor<64x256xi1> loc(#loc53) + %81 = tt.load %72, %80, %cst_0 : tensor<64x256x!tt.ptr> loc(#loc53) + %82 = tt.dot %77, %81, %arg11, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> loc(#loc54) + scf.yield %82 : tensor<128x256xf32> loc(#loc55) + } loc(#loc44) + %49 = tt.expand_dims %35 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc56) + %50 = arith.muli %20, %49 : tensor<128x1xi32> loc(#loc21) + %51 = tt.addptr %21, %50 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc22) + %52 = tt.expand_dims %37 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc57) + %53 = tt.broadcast %51 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> loc(#loc58) + %54 = tt.broadcast %52 : tensor<1x256xi32> -> tensor<128x256xi32> loc(#loc58) + %55 = tt.addptr %53, %54 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> loc(#loc58) + %56 = arith.cmpi slt, %49, %22 : tensor<128x1xi32> loc(#loc23) + %57 = arith.cmpi slt, %52, %23 : tensor<1x256xi32> loc(#loc24) + %58 = tt.broadcast %56 : tensor<128x1xi1> -> tensor<128x256xi1> loc(#loc59) + %59 = tt.broadcast %57 : tensor<1x256xi1> -> tensor<128x256xi1> loc(#loc59) + %60 = arith.andi %58, %59 : tensor<128x256xi1> loc(#loc59) + %61 = arith.truncf %48 : tensor<128x256xf32> to tensor<128x256xf16> loc(#loc60) + tt.store %55, %61, %60 : tensor<128x256x!tt.ptr> loc(#loc61) + } loc(#loc25) + tt.return loc(#loc62) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) +#loc3 = 
loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) +#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:38) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:35) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:41) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:41) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:37) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:37) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:49) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:30) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:79) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:30) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:53) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:53) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:37) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:25) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:37) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:62) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:47) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":176:30) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":177:33) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:39) +#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:52) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:41) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:31) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:27) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:48) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:26) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":183:26) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:28) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:28) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:49) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:49) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:38) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:61) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:68) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:60) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:24) +#loc45 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:26) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:41) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:68) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:68) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:64) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:24) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:64) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:24) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:39) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:12) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:45) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:76) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:56) +#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:43) +#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:31) +#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:25) +#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:4) +#loc63 = loc(callsite(#loc3 at #loc4)) +#loc64 = loc(callsite(#loc5 at #loc4)) +#loc65 = loc(callsite(#loc3 at #loc6)) +#loc66 = loc(callsite(#loc5 at #loc6)) +#loc67 = loc(callsite(#loc3 at #loc7)) +#loc68 = loc(callsite(#loc5 at #loc7)) diff --git a/test1.mlir b/test1.mlir new file mode 100644 index 000000000000..1c40c1375f8f --- /dev/null +++ b/test1.mlir @@ -0,0 +1,149 @@ +module { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = ub.poison : tensor<64x256xi32> + %1 = ub.poison : tensor<128x64xi32> + %2 = ub.poison : tensor<256xi32> + %3 = ub.poison : tensor<128xi32> + %4 = ub.poison : tensor<128x256xf32> + %5 = ub.poison : i32 + %c-1_i64 = arith.constant -1 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32> + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c132_i32 = arith.constant 132 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_2 = arith.constant dense<0> : tensor<256xi32> + %cst_3 = arith.constant dense<0> : tensor<128xi32> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %6 = tt.get_program_id x : i32 + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + %11 = arith.addi %arg5, %c63_i32 : 
i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %18 = tt.splat %arg3 : i32 -> tensor<128xi32> + %19 = tt.splat %arg4 : i32 -> tensor<256xi32> + %20 = tt.splat %arg6 : i32 -> tensor<128x1xi32> + %21 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %22 = tt.splat %arg7 : i32 -> tensor<1x256xi32> + %23 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %24 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %25 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %26 = tt.splat %arg8 : i32 -> tensor<128x1xi32> + %27 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %28 = tt.splat %arg3 : i32 -> tensor<128x1xi32> + %29 = tt.splat %arg4 : i32 -> tensor<1x256xi32> + %30 = arith.subi %13, %6 : i32 + %31 = arith.ceildivsi %30, %c132_i32 : i32 + %32 = arith.extsi %12 : i32 to i64 + %33 = arith.maxsi %32, %c1_i64 : i64 + %34 = arith.extsi %31 : i32 to i64 + %35 = arith.muli %34, %33 : i64 + %36 = arith.subi %33, %c1_i64 : i64 + %37:8 = scf.for %arg9 = %c0_i64 to %35 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0) -> (i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) : i64 { + %38 = arith.addi %arg10, %c1_i64 : i64 + %39 = arith.remsi %38, %33 : i64 + %40 = arith.cmpi eq, %39, %c0_i64 : i64 + %41 = arith.select %40, %c0_i32, %arg12 : i32 + %42 = arith.select %40, %cst, %arg13 : tensor<128x256xf32> + %43:4 = scf.if %40 -> (tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) { + %50 = arith.divsi %arg11, %14 : i32 + %51 = arith.muli %50, %c8_i32 : i32 + %52 = arith.subi %8, %51 : i32 + %53 = arith.minsi %52, %c8_i32 : i32 + %54 = arith.remsi %arg11, %53 : i32 + %55 = arith.addi %51, %54 : i32 + %56 = arith.remsi %arg11, %14 : i32 + %57 = arith.divsi %56, %53 : i32 + %58 = arith.muli %55, %c128_i32 : i32 + %59 = arith.muli %57, %c256_i32 : i32 + %60 = tt.splat %58 : i32 -> tensor<128xi32> + %61 = arith.addi %60, %16 : tensor<128xi32> + %62 = tt.splat %59 : i32 -> tensor<256xi32> + %63 = arith.addi %62, %17 : tensor<256xi32> + %64 = arith.cmpi slt, %61, %18 : tensor<128xi32> + %65 = arith.select %64, %61, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> + %66 = arith.cmpi slt, %63, %19 : tensor<256xi32> + %67 = arith.select %66, %63, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> + %68 = tt.expand_dims %65 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %69 = arith.muli %68, %20 : tensor<128x1xi32> + %70 = tt.broadcast %69 : tensor<128x1xi32> -> tensor<128x64xi32> + %71 = tt.expand_dims %67 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %72 = arith.muli %71, %22 : tensor<1x256xi32> + %73 = tt.broadcast %72 : tensor<1x256xi32> -> tensor<64x256xi32> + scf.yield %61, %63, %70, %73 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + } + %44 = arith.cmpi sge, 
%39, %c0_i64 : i64 + %45 = arith.cmpi slt, %39, %32 : i64 + %46 = arith.andi %44, %45 : i1 + %47:2 = scf.if %46 -> (i32, tensor<128x256xf32>) { + %50 = arith.muli %41, %c64_i32 : i32 + %51 = tt.splat %50 : i32 -> tensor<64xi32> + %52 = arith.addi %51, %15 : tensor<64xi32> + %53 = tt.expand_dims %52 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %54 = tt.broadcast %53 : tensor<1x64xi32> -> tensor<128x64xi32> + %55 = arith.addi %43#2, %54 : tensor<128x64xi32> + %56 = tt.addptr %21, %55 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %57 = tt.expand_dims %52 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %58 = tt.broadcast %57 : tensor<64x1xi32> -> tensor<64x256xi32> + %59 = arith.addi %58, %43#3 : tensor<64x256xi32> + %60 = tt.addptr %23, %59 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %61 = arith.subi %arg5, %50 : i32 + %62 = tt.splat %61 : i32 -> tensor<1x64xi32> + %63 = arith.cmpi slt, %24, %62 : tensor<1x64xi32> + %64 = tt.broadcast %63 : tensor<1x64xi1> -> tensor<128x64xi1> + %65 = tt.load %56, %64, %cst_1 : tensor<128x64x!tt.ptr> + %66 = tt.splat %61 : i32 -> tensor<64x1xi32> + %67 = arith.cmpi slt, %25, %66 : tensor<64x1xi32> + %68 = tt.broadcast %67 : tensor<64x1xi1> -> tensor<64x256xi1> + %69 = tt.load %60, %68, %cst_0 : tensor<64x256x!tt.ptr> + %70 = tt.dot %65, %69, %42, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> + %71 = arith.addi %41, %c1_i32 : i32 + scf.yield %71, %70 : i32, tensor<128x256xf32> + } else { + scf.yield %41, %arg13 : i32, tensor<128x256xf32> + } + %48 = arith.cmpi eq, %39, %36 : i64 + %49 = scf.if %48 -> (i32) { + %50 = tt.expand_dims %43#0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %51 = arith.muli %26, %50 : tensor<128x1xi32> + %52 = tt.addptr %27, %51 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %53 = tt.expand_dims %43#1 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %54 = tt.broadcast %52 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %55 = tt.broadcast %53 : tensor<1x256xi32> -> tensor<128x256xi32> + %56 = tt.addptr %54, %55 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %57 = arith.cmpi slt, %50, %28 : tensor<128x1xi32> + %58 = arith.cmpi slt, %53, %29 : tensor<1x256xi32> + %59 = tt.broadcast %57 : tensor<128x1xi1> -> tensor<128x256xi1> + %60 = tt.broadcast %58 : tensor<1x256xi1> -> tensor<128x256xi1> + %61 = arith.andi %59, %60 : tensor<128x256xi1> + %62 = arith.truncf %47#1 : tensor<128x256xf32> to tensor<128x256xf16> + tt.store %56, %62, %61 : tensor<128x256x!tt.ptr> + %63 = arith.addi %arg11, %c132_i32 : i32 + scf.yield %63 : i32 + } else { + scf.yield %arg11 : i32 + } + scf.yield %39, %49, %47#0, %47#1, %43#0, %43#1, %43#2, %43#3 : i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + } + tt.return + } +} + From 89db13caaedd5306fc05e7263c2ee2ef56b89486 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 23 Jan 2025 22:50:21 -0500 Subject: [PATCH 02/32] almost --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 71 ++++++ .../Transforms/Pipeliner/AssignLatencies.cpp | 4 + test1.mlir | 220 ++++++++++-------- test2.mlir | 166 +++++++++++++ test3.mlir | 197 ++++++++++++++++ 5 files changed, 561 insertions(+), 97 deletions(-) create mode 100644 test2.mlir create mode 100644 test3.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index ecba8c8582d7..3fb34ad5f40b 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ 
b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
@@ -17,6 +17,8 @@ namespace gpu {
 #define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
+static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute";
+
 namespace {
 struct FuseNestedLoopsPass
     : public impl::TritonGPUFuseNestedLoopsBase<FuseNestedLoopsPass> {
@@ -705,6 +707,14 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   inner.replaceAllUsesWith(
       bodyIf.getResults().slice(1, inner.getNumResults()));
 
+  // If the inner loop must execute, then its body does not have to be wrapped
+  // in a conditional.
+  if (inner->hasAttr(kMustExecuteAttrName)) {
+    b.setInsertionPoint(bodyIf);
+    bodyIf.getConditionMutable().assign(
+        b.create<arith::ConstantOp>(loc, b.getBoolAttr(true)));
+  }
+
   // Move the insertion point for the next iteration.
   b.setInsertionPointAfter(bodyIf);
 }
@@ -790,6 +800,65 @@ static bool shouldFuse(const LoopNest &nest) {
                      [](Operation &op) { return isa<triton::DotOp>(op); });
 }
 
+// Speculate on the inner loop length so that the loop is known to execute at
+// least once. This way, the inner loop body does not have to be placed
+// inside a conditional in the fused loop, which interacts better with the
+// pipeliner.
+static LogicalResult speculateInnerLoopLength(const LoopNest &nest,
+                                              mlir::DominanceInfo &domInfo) {
+  assert(nest.nodes.size() == 2 && nest.root->children.size() == 1);
+
+  scf::ForOp outerLoop = nest.root->loop;
+  scf::ForOp innerLoop = nest.root->children.front()->loop;
+
+  // The inner loop bounds must be outer-loop invariant to speculate from
+  // outside the loop nest.
+  Location loc = innerLoop.getLoc();
+  llvm::SetVector<Operation *> toHoist;
+  if (!isOuterLoopInvariant(domInfo, outerLoop,
+                            {innerLoop.getLowerBound(),
+                             innerLoop.getUpperBound(), innerLoop.getStep()},
+                            toHoist))
+    return failure();
+
+  // Hoist the inner loop bounds computations if necessary.
+  toHoist = topologicalSort(toHoist);
+  for (Operation *op : toHoist)
+    op->moveBefore(outerLoop);
+
+  // Mark the inner loop.
+  OpBuilder b(outerLoop);
+  innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr());
+
+  // Speculate on whether the length of the inner loop is zero.
+  Value lenInner = computeNumIters(b, innerLoop);
+  auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0);
+  Value innerLoopEmpty =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lenInner,
+                              b.create<arith::ConstantOp>(loc, zeroAttr));
+  auto ifOp =
+      b.create<scf::IfOp>(loc, outerLoop.getResultTypes(), innerLoopEmpty);
+
+  // In the `then` branch, the inner loop does not execute. Clone the loop nest
+  // into it and remove the inner loop.
+  mlir::IRMapping map;
+  b.createBlock(&ifOp.getThenRegion());
+  auto newLoop = cast<scf::ForOp>(b.clone(*outerLoop, map));
+  b.create<scf::YieldOp>(loc, newLoop.getResults());
+  auto newInnerLoop = cast<scf::ForOp>(map.lookup(innerLoop));
+  newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits());
+  newInnerLoop.erase();
+
+  // Move the loop nest into the `else` branch.
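+  // The net effect of this speculation is roughly the following shape
+  // (illustrative sketch only, not the literal IR built here):
+  //
+  //   %empty = arith.cmpi eq, %lenInner, %c0
+  //   scf.if %empty {
+  //     <clone of the outer loop with the inner loop erased>
+  //   } else {
+  //     <original loop nest; flattened afterwards, and its inner body needs
+  //      no guard because the inner loop executes at least once>
+  //   }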
+ outerLoop.replaceAllUsesWith(ifOp.getResults()); + Block *block = b.createBlock(&ifOp.getElseRegion()); + outerLoop->remove(); + b.insert(outerLoop); + b.create(loc, outerLoop.getResults()); + + return success(); +} + void FuseNestedLoopsPass::runOnOperation() { auto &domInfo = getAnalysis(); @@ -799,6 +868,8 @@ void FuseNestedLoopsPass::runOnOperation() { for (LoopNest &nest : nests) { if (!shouldFuse(nest)) continue; + if (failed(speculateInnerLoopLength(nest, domInfo))) + continue; flattenLoopNest(nest.root, domInfo); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp index 20d5d418fdfa..5226c8e831a4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -125,9 +125,11 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, [&](Operation *op, Operation *finalUser, int distance) { if (!seen.insert(op).second || excluded.count(op)) return; + op->dump(); if (isa(op)) { if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis)) return; + op->dump(); if (loadOpToIndLevel.count(op)) { int level = loadOpToIndLevel[op]; if (level != distance) { @@ -168,6 +170,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, continue; seenDot = true; seen.clear(); + op.dump(); dfs(&op, &op, 0); } @@ -229,6 +232,7 @@ DenseMap assignLatencies(ModuleOp moduleOp, DenseMap opLatency; for (auto forOp : loops) { + forOp.dump(); if (hasLatenciesAssigned(forOp)) { assignUserProvidedLatencies(forOp, opLatency); continue; diff --git a/test1.mlir b/test1.mlir index 1c40c1375f8f..691cd743c1bb 100644 --- a/test1.mlir +++ b/test1.mlir @@ -1,5 +1,6 @@ module { tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16> %0 = ub.poison : tensor<64x256xi32> %1 = ub.poison : tensor<128x64xi32> %2 = ub.poison : tensor<256xi32> @@ -9,18 +10,18 @@ module { %c-1_i64 = arith.constant -1 : i64 %c1_i64 = arith.constant 1 : i64 %c0_i64 = arith.constant 0 : i64 - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32> %c63_i32 = arith.constant 63 : i32 %c255_i32 = arith.constant 255 : i32 %c127_i32 = arith.constant 127 : i32 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %c132_i32 = arith.constant 132 : i32 %c64_i32 = arith.constant 64 : i32 - %cst_2 = arith.constant dense<0> : tensor<256xi32> - %cst_3 = arith.constant dense<0> : tensor<128xi32> + %cst_3 = arith.constant dense<0> : tensor<256xi32> + %cst_4 = arith.constant dense<0> : tensor<128xi32> %c256_i32 = arith.constant 256 : i32 %c128_i32 = arith.constant 128 : i32 %c8_i32 = arith.constant 8 : i32 @@ -48,100 +49,125 @@ module { %27 = tt.splat %arg2 
: !tt.ptr -> tensor<128x1x!tt.ptr> %28 = tt.splat %arg3 : i32 -> tensor<128x1xi32> %29 = tt.splat %arg4 : i32 -> tensor<1x256xi32> - %30 = arith.subi %13, %6 : i32 - %31 = arith.ceildivsi %30, %c132_i32 : i32 - %32 = arith.extsi %12 : i32 to i64 - %33 = arith.maxsi %32, %c1_i64 : i64 - %34 = arith.extsi %31 : i32 to i64 - %35 = arith.muli %34, %33 : i64 - %36 = arith.subi %33, %c1_i64 : i64 - %37:8 = scf.for %arg9 = %c0_i64 to %35 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0) -> (i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) : i64 { - %38 = arith.addi %arg10, %c1_i64 : i64 - %39 = arith.remsi %38, %33 : i64 - %40 = arith.cmpi eq, %39, %c0_i64 : i64 - %41 = arith.select %40, %c0_i32, %arg12 : i32 - %42 = arith.select %40, %cst, %arg13 : tensor<128x256xf32> - %43:4 = scf.if %40 -> (tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) { - %50 = arith.divsi %arg11, %14 : i32 - %51 = arith.muli %50, %c8_i32 : i32 - %52 = arith.subi %8, %51 : i32 - %53 = arith.minsi %52, %c8_i32 : i32 - %54 = arith.remsi %arg11, %53 : i32 - %55 = arith.addi %51, %54 : i32 - %56 = arith.remsi %arg11, %14 : i32 - %57 = arith.divsi %56, %53 : i32 - %58 = arith.muli %55, %c128_i32 : i32 - %59 = arith.muli %57, %c256_i32 : i32 - %60 = tt.splat %58 : i32 -> tensor<128xi32> - %61 = arith.addi %60, %16 : tensor<128xi32> - %62 = tt.splat %59 : i32 -> tensor<256xi32> - %63 = arith.addi %62, %17 : tensor<256xi32> - %64 = arith.cmpi slt, %61, %18 : tensor<128xi32> - %65 = arith.select %64, %61, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> - %66 = arith.cmpi slt, %63, %19 : tensor<256xi32> - %67 = arith.select %66, %63, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> - %68 = tt.expand_dims %65 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %69 = arith.muli %68, %20 : tensor<128x1xi32> - %70 = tt.broadcast %69 : tensor<128x1xi32> -> tensor<128x64xi32> - %71 = tt.expand_dims %67 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %72 = arith.muli %71, %22 : tensor<1x256xi32> - %73 = tt.broadcast %72 : tensor<1x256xi32> -> tensor<64x256xi32> - scf.yield %61, %63, %70, %73 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> - } else { - scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + %30 = arith.cmpi eq, %12, %c0_i32 : i32 + scf.if %30 { + scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { + %31 = arith.divsi %arg9, %14 : i32 + %32 = arith.muli %31, %c8_i32 : i32 + %33 = arith.subi %8, %32 : i32 + %34 = arith.minsi %33, %c8_i32 : i32 + %35 = arith.remsi %arg9, %34 : i32 + %36 = arith.addi %32, %35 : i32 + %37 = arith.remsi %arg9, %14 : i32 + %38 = arith.divsi %37, %34 : i32 + %39 = arith.muli %36, %c128_i32 : i32 + %40 = arith.muli %38, %c256_i32 : i32 + %41 = tt.splat %39 : i32 -> tensor<128xi32> + %42 = arith.addi %41, %16 : tensor<128xi32> + %43 = tt.splat %40 : i32 -> tensor<256xi32> + %44 = arith.addi %43, %17 : tensor<256xi32> + %45 = tt.expand_dims %42 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %46 = arith.muli %26, %45 : tensor<128x1xi32> + %47 = tt.addptr %27, %46 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %48 = tt.expand_dims %44 {axis = 0 : i32} : 
tensor<256xi32> -> tensor<1x256xi32> + %49 = tt.broadcast %47 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %50 = tt.broadcast %48 : tensor<1x256xi32> -> tensor<128x256xi32> + %51 = tt.addptr %49, %50 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %52 = arith.cmpi slt, %45, %28 : tensor<128x1xi32> + %53 = arith.cmpi slt, %48, %29 : tensor<1x256xi32> + %54 = tt.broadcast %52 : tensor<128x1xi1> -> tensor<128x256xi1> + %55 = tt.broadcast %53 : tensor<1x256xi1> -> tensor<128x256xi1> + %56 = arith.andi %54, %55 : tensor<128x256xi1> + tt.store %51, %cst, %56 : tensor<128x256x!tt.ptr> } - %44 = arith.cmpi sge, %39, %c0_i64 : i64 - %45 = arith.cmpi slt, %39, %32 : i64 - %46 = arith.andi %44, %45 : i1 - %47:2 = scf.if %46 -> (i32, tensor<128x256xf32>) { - %50 = arith.muli %41, %c64_i32 : i32 - %51 = tt.splat %50 : i32 -> tensor<64xi32> - %52 = arith.addi %51, %15 : tensor<64xi32> - %53 = tt.expand_dims %52 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %54 = tt.broadcast %53 : tensor<1x64xi32> -> tensor<128x64xi32> - %55 = arith.addi %43#2, %54 : tensor<128x64xi32> - %56 = tt.addptr %21, %55 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %57 = tt.expand_dims %52 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %58 = tt.broadcast %57 : tensor<64x1xi32> -> tensor<64x256xi32> - %59 = arith.addi %58, %43#3 : tensor<64x256xi32> - %60 = tt.addptr %23, %59 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> - %61 = arith.subi %arg5, %50 : i32 - %62 = tt.splat %61 : i32 -> tensor<1x64xi32> - %63 = arith.cmpi slt, %24, %62 : tensor<1x64xi32> - %64 = tt.broadcast %63 : tensor<1x64xi1> -> tensor<128x64xi1> - %65 = tt.load %56, %64, %cst_1 : tensor<128x64x!tt.ptr> - %66 = tt.splat %61 : i32 -> tensor<64x1xi32> - %67 = arith.cmpi slt, %25, %66 : tensor<64x1xi32> - %68 = tt.broadcast %67 : tensor<64x1xi1> -> tensor<64x256xi1> - %69 = tt.load %60, %68, %cst_0 : tensor<64x256x!tt.ptr> - %70 = tt.dot %65, %69, %42, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> - %71 = arith.addi %41, %c1_i32 : i32 - scf.yield %71, %70 : i32, tensor<128x256xf32> - } else { - scf.yield %41, %arg13 : i32, tensor<128x256xf32> + } else { + %31 = arith.subi %13, %6 : i32 + %32 = arith.ceildivsi %31, %c132_i32 : i32 + %33 = arith.extsi %12 : i32 to i64 + %34 = arith.maxsi %33, %c1_i64 : i64 + %35 = arith.extsi %32 : i32 to i64 + %36 = arith.muli %35, %34 : i64 + %37 = arith.subi %34, %c1_i64 : i64 + %38:8 = scf.for %arg9 = %c0_i64 to %36 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0) -> (i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) : i64 { + %39 = arith.addi %arg10, %c1_i64 : i64 + %40 = arith.remsi %39, %34 : i64 + %41 = arith.cmpi eq, %40, %c0_i64 : i64 + %42 = arith.select %41, %c0_i32, %arg12 : i32 + %43 = arith.select %41, %cst_0, %arg13 : tensor<128x256xf32> + %44:4 = scf.if %41 -> (tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) { + %69 = arith.divsi %arg11, %14 : i32 + %70 = arith.muli %69, %c8_i32 : i32 + %71 = arith.subi %8, %70 : i32 + %72 = arith.minsi %71, %c8_i32 : i32 + %73 = arith.remsi %arg11, %72 : i32 + %74 = arith.addi %70, %73 : i32 + %75 = arith.remsi %arg11, %14 : i32 + %76 = arith.divsi %75, %72 : i32 + %77 = arith.muli %74, %c128_i32 : i32 + %78 = arith.muli %76, %c256_i32 : i32 + %79 = tt.splat %77 : i32 -> tensor<128xi32> + %80 = arith.addi %79, %16 : tensor<128xi32> + %81 = 
tt.splat %78 : i32 -> tensor<256xi32> + %82 = arith.addi %81, %17 : tensor<256xi32> + %83 = arith.cmpi slt, %80, %18 : tensor<128xi32> + %84 = arith.select %83, %80, %cst_4 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> + %85 = arith.cmpi slt, %82, %19 : tensor<256xi32> + %86 = arith.select %85, %82, %cst_3 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> + %87 = tt.expand_dims %84 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %88 = arith.muli %87, %20 : tensor<128x1xi32> + %89 = tt.broadcast %88 : tensor<128x1xi32> -> tensor<128x64xi32> + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %91 = arith.muli %90, %22 : tensor<1x256xi32> + %92 = tt.broadcast %91 : tensor<1x256xi32> -> tensor<64x256xi32> + scf.yield %80, %82, %89, %92 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> + } + %45 = arith.muli %42, %c64_i32 : i32 + %46 = tt.splat %45 : i32 -> tensor<64xi32> + %47 = arith.addi %46, %15 : tensor<64xi32> + %48 = tt.expand_dims %47 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %49 = tt.broadcast %48 : tensor<1x64xi32> -> tensor<128x64xi32> + %50 = arith.addi %44#2, %49 : tensor<128x64xi32> + %51 = tt.addptr %21, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %52 = tt.expand_dims %47 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %53 = tt.broadcast %52 : tensor<64x1xi32> -> tensor<64x256xi32> + %54 = arith.addi %53, %44#3 : tensor<64x256xi32> + %55 = tt.addptr %23, %54 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %56 = arith.subi %arg5, %45 : i32 + %57 = tt.splat %56 : i32 -> tensor<1x64xi32> + %58 = arith.cmpi slt, %24, %57 : tensor<1x64xi32> + %59 = tt.broadcast %58 : tensor<1x64xi1> -> tensor<128x64xi1> + %60 = tt.load %51, %59, %cst_2 : tensor<128x64x!tt.ptr> + %61 = tt.splat %56 : i32 -> tensor<64x1xi32> + %62 = arith.cmpi slt, %25, %61 : tensor<64x1xi32> + %63 = tt.broadcast %62 : tensor<64x1xi1> -> tensor<64x256xi1> + %64 = tt.load %55, %63, %cst_1 : tensor<64x256x!tt.ptr> + %65 = tt.dot %60, %64, %43, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> + %66 = arith.addi %42, %c1_i32 : i32 + %67 = arith.cmpi eq, %40, %37 : i64 + %68 = scf.if %67 -> (i32) { + %69 = tt.expand_dims %44#0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %70 = arith.muli %26, %69 : tensor<128x1xi32> + %71 = tt.addptr %27, %70 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %72 = tt.expand_dims %44#1 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %73 = tt.broadcast %71 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %74 = tt.broadcast %72 : tensor<1x256xi32> -> tensor<128x256xi32> + %75 = tt.addptr %73, %74 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %76 = arith.cmpi slt, %69, %28 : tensor<128x1xi32> + %77 = arith.cmpi slt, %72, %29 : tensor<1x256xi32> + %78 = tt.broadcast %76 : tensor<128x1xi1> -> tensor<128x256xi1> + %79 = tt.broadcast %77 : tensor<1x256xi1> -> tensor<128x256xi1> + %80 = arith.andi %78, %79 : tensor<128x256xi1> + %81 = arith.truncf %65 : tensor<128x256xf32> to tensor<128x256xf16> + tt.store %75, %81, %80 : tensor<128x256x!tt.ptr> + %82 = arith.addi %arg11, %c132_i32 : i32 + scf.yield %82 : i32 + } else { + scf.yield %arg11 : i32 + } + scf.yield %40, 
%68, %66, %65, %44#0, %44#1, %44#2, %44#3 : i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> } - %48 = arith.cmpi eq, %39, %36 : i64 - %49 = scf.if %48 -> (i32) { - %50 = tt.expand_dims %43#0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %51 = arith.muli %26, %50 : tensor<128x1xi32> - %52 = tt.addptr %27, %51 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - %53 = tt.expand_dims %43#1 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %54 = tt.broadcast %52 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> - %55 = tt.broadcast %53 : tensor<1x256xi32> -> tensor<128x256xi32> - %56 = tt.addptr %54, %55 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %57 = arith.cmpi slt, %50, %28 : tensor<128x1xi32> - %58 = arith.cmpi slt, %53, %29 : tensor<1x256xi32> - %59 = tt.broadcast %57 : tensor<128x1xi1> -> tensor<128x256xi1> - %60 = tt.broadcast %58 : tensor<1x256xi1> -> tensor<128x256xi1> - %61 = arith.andi %59, %60 : tensor<128x256xi1> - %62 = arith.truncf %47#1 : tensor<128x256xf32> to tensor<128x256xf16> - tt.store %56, %62, %61 : tensor<128x256x!tt.ptr> - %63 = arith.addi %arg11, %c132_i32 : i32 - scf.yield %63 : i32 - } else { - scf.yield %arg11 : i32 - } - scf.yield %39, %49, %47#0, %47#1, %43#0, %43#1, %43#2, %43#3 : i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> } tt.return } diff --git a/test2.mlir b/test2.mlir new file mode 100644 index 000000000000..6faf32260e19 --- /dev/null +++ b/test2.mlir @@ -0,0 +1,166 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c-1_i64 = arith.constant -1 : i64 + %0 = ub.poison : i32 + %1 = ub.poison : tensor<128x256xf32, #mma> + %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %4 = ub.poison : tensor<128x64xi32, #blocked1> + %5 = ub.poison : tensor<64x256xi32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, 
#blocked2> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %6 = tt.get_program_id x : i32 + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + %11 = arith.addi %arg5, %c63_i32 : i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %23 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %25 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %26 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %27 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %28 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %29 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %30 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> + %32 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %34 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> + %36 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> + %38 = arith.cmpi eq, %12, %c0_i32 : i32 + %39 = arith.subi %13, %6 : i32 + %40 = arith.ceildivsi %39, %c132_i32 : i32 + %41 = arith.extsi %12 : i32 to i64 + %42 = arith.maxsi %41, %c1_i64 : i64 + %43 = arith.extsi %40 : i32 to i64 + %44 = arith.muli %43, %42 : i64 + %45 = arith.subi %42, %c1_i64 : i64 + %true = arith.constant true + %false = arith.constant false + %46:9 = scf.for %arg9 = %c0_i64 to %44 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1) : i64 { + %47 = arith.addi %arg10, %c1_i64 : i64 + %48 = arith.remsi %47, %42 : i64 + %49 = arith.cmpi eq, %48, %c0_i64 : i64 + %50 = arith.select %49, %c0_i32, %arg12 : i32 + %51 = arith.select %49, %false, %arg18 : i1 + %52:4 = scf.if %49 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) { + %81 = arith.divsi %arg11, %14 : i32 + %82 = arith.muli %81, %c8_i32 : i32 + %83 = arith.subi %8, %82 : i32 + %84 = arith.minsi %83, %c8_i32 : i32 + %85 = arith.remsi %arg11, %84 : i32 + %86 = arith.addi %82, %85 : i32 + %87 = arith.remsi %arg11, %14 : i32 + %88 = arith.divsi %87, %84 : i32 + %89 = arith.muli %86, %c128_i32 : i32 + %90 = arith.muli %88, %c256_i32 : i32 + %91 = tt.splat %89 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %92 = arith.addi %91, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %93 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %94 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %95 = arith.addi %93, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %96 = arith.addi %94, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %97 = arith.cmpi slt, %92, %22 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %98 = arith.select %97, %92, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %99 = arith.cmpi slt, %95, %23 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %100 = arith.select %99, %95, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %101 = tt.expand_dims %98 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %102 = arith.muli %101, %24 : tensor<128x1xi32, #blocked1> + %103 = tt.broadcast %102 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %104 = tt.expand_dims %100 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %105 = arith.muli %104, %26 : tensor<1x256xi32, #blocked> + %106 = tt.broadcast %105 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + scf.yield %103, %106, %96, %92 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + } + %53 = arith.muli %50, %c64_i32 : i32 + %54 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %55 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %56 = arith.addi %54, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %57 = arith.addi %55, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = 
#blocked}>> + %58 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %59 = tt.broadcast %58 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %60 = arith.addi %52#0, %59 : tensor<128x64xi32, #blocked1> + %61 = tt.addptr %25, %60 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %62 = tt.expand_dims %57 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %63 = tt.broadcast %62 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %64 = arith.addi %63, %52#1 : tensor<64x256xi32, #blocked> + %65 = tt.addptr %27, %64 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %66 = arith.subi %arg5, %53 : i32 + %67 = tt.splat %66 : i32 -> tensor<1x64xi32, #blocked1> + %68 = arith.cmpi slt, %28, %67 : tensor<1x64xi32, #blocked1> + %69 = tt.broadcast %68 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %70 = tt.load %61, %69, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> + %71 = ttg.local_alloc %70 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %72 = tt.splat %66 : i32 -> tensor<64x1xi32, #blocked> + %73 = arith.cmpi slt, %29, %72 : tensor<64x1xi32, #blocked> + %74 = tt.broadcast %73 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %75 = tt.load %65, %74, %cst_3 : tensor<64x256x!tt.ptr, #blocked> + %76 = ttg.local_alloc %75 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %77 = ttng.warp_group_dot %71, %76, %arg13, %51 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> + %78 = arith.addi %50, %c1_i32 : i32 + %79 = arith.cmpi eq, %48, %45 : i64 + %80 = scf.if %79 -> (i32) { + %81 = tt.expand_dims %52#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %82 = arith.muli %31, %81 : tensor<128x1xi32, #blocked1> + %83 = tt.addptr %33, %82 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %84 = tt.expand_dims %52#2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %85 = tt.broadcast %83 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %86 = tt.broadcast %84 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %87 = tt.addptr %85, %86 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %88 = arith.cmpi slt, %81, %35 : tensor<128x1xi32, #blocked1> + %89 = arith.cmpi slt, %84, %37 : tensor<1x256xi32, #blocked1> + %90 = tt.broadcast %88 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %91 = tt.broadcast %89 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %92 = arith.andi %90, %91 : tensor<128x256xi1, #blocked1> + %93 = arith.truncf %77 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %94 = ttg.convert_layout %93 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.store %87, %94, %92 : tensor<128x256x!tt.ptr, #blocked1> + %95 = arith.addi %arg11, %c132_i32 : i32 + scf.yield %95 : i32 + } else { + scf.yield %arg11 : i32 + } + scf.yield %48, %80, %78, %77, %52#0, %52#1, %52#2, %52#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = 
#blocked1}>>, i1 + } + tt.return + } +} + diff --git a/test3.mlir b/test3.mlir new file mode 100644 index 000000000000..9b84523f6053 --- /dev/null +++ b/test3.mlir @@ -0,0 +1,197 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %false = arith.constant false + %true = arith.constant true + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c-1_i64 = arith.constant -1 : i64 + %0 = ub.poison : i32 + %1 = ub.poison : tensor<128x256xf32, #mma> + %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %4 = ub.poison : tensor<128x64xi32, #blocked1> + %5 = ub.poison : tensor<64x256xi32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %6 = tt.get_program_id x : i32 + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + %11 = arith.addi %arg5, %c63_i32 : i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 
0, parent = #blocked2}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %23 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %25 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %26 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %27 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %28 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %29 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %30 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> + %32 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %34 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> + %36 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> + %38 = arith.cmpi eq, %12, %c0_i32 : i32 + scf.if %38 { + scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { + %39 = arith.divsi %arg9, %14 : i32 + %40 = arith.muli %39, %c8_i32 : i32 + %41 = arith.subi %8, %40 : i32 + %42 = arith.minsi %41, %c8_i32 : i32 + %43 = arith.remsi %arg9, %42 : i32 + %44 = arith.addi %40, %43 : i32 + %45 = arith.remsi %arg9, %14 : i32 + %46 = arith.divsi %45, %42 : i32 + %47 = arith.muli %44, %c128_i32 : i32 + %48 = arith.muli %46, %c256_i32 : i32 + %49 = tt.splat %47 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %50 = arith.addi %49, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %51 = tt.splat %48 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %52 = arith.addi %51, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %53 = tt.expand_dims %50 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %54 = arith.muli %30, %53 : tensor<128x1xi32, #blocked2> + %55 = tt.addptr %32, %54 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %56 = tt.expand_dims %52 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %57 = tt.broadcast %55 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %59 = tt.addptr %57, %58 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %60 = arith.cmpi slt, %53, %34 : tensor<128x1xi32, #blocked2> + %61 = arith.cmpi slt, %56, %36 : tensor<1x256xi32, #blocked2> + %62 = tt.broadcast %60 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %63 = tt.broadcast %61 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %64 = arith.andi %62, %63 : tensor<128x256xi1, #blocked2> + tt.store %59, %cst_1, %64 : tensor<128x256x!tt.ptr, #blocked2> + } + } else { + %39 = arith.subi %13, %6 : i32 + %40 = arith.ceildivsi %39, %c132_i32 : i32 + %41 = 
arith.extsi %12 : i32 to i64 + %42 = arith.maxsi %41, %c1_i64 : i64 + %43 = arith.extsi %40 : i32 to i64 + %44 = arith.muli %43, %42 : i64 + %45 = arith.subi %42, %c1_i64 : i64 + %46:9 = scf.for %arg9 = %c0_i64 to %44 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1) : i64 { + %47 = arith.addi %arg10, %c1_i64 : i64 + %48 = arith.remsi %47, %42 : i64 + %49 = arith.cmpi eq, %48, %c0_i64 : i64 + %50 = arith.select %49, %c0_i32, %arg12 : i32 + %51 = arith.select %49, %false, %arg18 : i1 + %52:4 = scf.if %49 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) { + %81 = arith.divsi %arg11, %14 : i32 + %82 = arith.muli %81, %c8_i32 : i32 + %83 = arith.subi %8, %82 : i32 + %84 = arith.minsi %83, %c8_i32 : i32 + %85 = arith.remsi %arg11, %84 : i32 + %86 = arith.addi %82, %85 : i32 + %87 = arith.remsi %arg11, %14 : i32 + %88 = arith.divsi %87, %84 : i32 + %89 = arith.muli %86, %c128_i32 : i32 + %90 = arith.muli %88, %c256_i32 : i32 + %91 = tt.splat %89 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %92 = arith.addi %91, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %93 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %94 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %95 = arith.addi %93, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %96 = arith.addi %94, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %97 = arith.cmpi slt, %92, %22 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %98 = arith.select %97, %92, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %99 = arith.cmpi slt, %95, %23 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %100 = arith.select %99, %95, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %101 = tt.expand_dims %98 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %102 = arith.muli %101, %24 : tensor<128x1xi32, #blocked1> + %103 = tt.broadcast %102 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %104 = tt.expand_dims %100 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %105 = arith.muli %104, %26 : tensor<1x256xi32, #blocked> + %106 = tt.broadcast %105 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + scf.yield %103, %106, %96, %92 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 
0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + } + %53 = arith.muli %50, %c64_i32 : i32 + %54 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %55 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %56 = arith.addi %54, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %57 = arith.addi %55, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %58 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %59 = tt.broadcast %58 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %60 = arith.addi %52#0, %59 : tensor<128x64xi32, #blocked1> + %61 = tt.addptr %25, %60 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %62 = tt.expand_dims %57 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %63 = tt.broadcast %62 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %64 = arith.addi %63, %52#1 : tensor<64x256xi32, #blocked> + %65 = tt.addptr %27, %64 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %66 = arith.subi %arg5, %53 : i32 + %67 = tt.splat %66 : i32 -> tensor<1x64xi32, #blocked1> + %68 = arith.cmpi slt, %28, %67 : tensor<1x64xi32, #blocked1> + %69 = tt.broadcast %68 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %70 = tt.load %61, %69, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> + %71 = ttg.local_alloc %70 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %72 = tt.splat %66 : i32 -> tensor<64x1xi32, #blocked> + %73 = arith.cmpi slt, %29, %72 : tensor<64x1xi32, #blocked> + %74 = tt.broadcast %73 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %75 = tt.load %65, %74, %cst_3 : tensor<64x256x!tt.ptr, #blocked> + %76 = ttg.local_alloc %75 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %77 = ttng.warp_group_dot %71, %76, %arg13, %51 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> + %78 = arith.addi %50, %c1_i32 : i32 + %79 = arith.cmpi eq, %48, %45 : i64 + %80 = scf.if %79 -> (i32) { + %81 = tt.expand_dims %52#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %82 = arith.muli %31, %81 : tensor<128x1xi32, #blocked1> + %83 = tt.addptr %33, %82 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %84 = tt.expand_dims %52#2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %85 = tt.broadcast %83 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %86 = tt.broadcast %84 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %87 = tt.addptr %85, %86 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %88 = arith.cmpi slt, %81, %35 : tensor<128x1xi32, #blocked1> + %89 = arith.cmpi slt, %84, %37 : tensor<1x256xi32, #blocked1> + %90 = tt.broadcast %88 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %91 = tt.broadcast %89 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %92 = arith.andi %90, %91 : tensor<128x256xi1, #blocked1> + %93 = arith.truncf %77 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %94 = ttg.convert_layout %93 : tensor<128x256xf16, #mma> -> 
tensor<128x256xf16, #blocked1> + tt.store %87, %94, %92 : tensor<128x256x!tt.ptr, #blocked1> + %95 = arith.addi %arg11, %c132_i32 : i32 + scf.yield %95 : i32 + } else { + scf.yield %arg11 : i32 + } + scf.yield %48, %80, %78, %77, %52#0, %52#1, %52#2, %52#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1 + } + } + tt.return + } +} + From 03655835bd775958ceac9097e04ac0185c4abdea Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 23 Jan 2025 23:34:45 -0500 Subject: [PATCH 03/32] axisinfo for poison op --- lib/Analysis/AxisInfo.cpp | 24 ++++ test2.mlir | 281 +++++++++++++++++++++----------------- 2 files changed, 183 insertions(+), 122 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 004275c1cfb3..cea14bec9735 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1,5 +1,6 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -270,6 +271,28 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl { } }; +class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(ub::PoisonOp op, + ArrayRef *> operands) override { + constexpr int64_t largePowerOf2 = int64_t(1) << 32; + // Poison values are never accessed, thus assume optimistic values. + if (auto shape = dyn_cast(op.getType())) { + unsigned rank = shape.getRank(); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(rank, 1), + /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), + /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); + } + + return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2}, + /*constancy=*/{1}); + } +}; + template class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { public: @@ -1018,6 +1041,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) visitors.append(); visitors.append, ConstantOpAxisInfoVisitor>(); + visitors.append(); visitors.append, AddSubOpAxisInfoVisitor, AddSubOpAxisInfoVisitor, diff --git a/test2.mlir b/test2.mlir index 6faf32260e19..9f3a4304fd97 100644 --- a/test2.mlir +++ b/test2.mlir @@ -1,8 +1,10 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 
8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { @@ -16,11 +18,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c-1_i64 = arith.constant -1 : i64 %0 = ub.poison : i32 %1 = ub.poison : tensor<128x256xf32, #mma> - %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> %4 = ub.poison : tensor<128x64xi32, #blocked1> %5 = ub.poison : tensor<64x256xi32, #blocked> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked3> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> %c64_i32 = arith.constant 64 : i32 @@ -42,123 +44,158 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %14 = arith.muli %10, %c8_i32 : i32 %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %23 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %24 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %25 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %26 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %27 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %28 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %29 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %30 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> - %32 = tt.splat %arg2 : !tt.ptr -> 
tensor<128x1x!tt.ptr, #blocked2> - %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %34 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> - %36 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> - %38 = arith.cmpi eq, %12, %c0_i32 : i32 - %39 = arith.subi %13, %6 : i32 - %40 = arith.ceildivsi %39, %c132_i32 : i32 - %41 = arith.extsi %12 : i32 to i64 - %42 = arith.maxsi %41, %c1_i64 : i64 - %43 = arith.extsi %40 : i32 to i64 - %44 = arith.muli %43, %42 : i64 - %45 = arith.subi %42, %c1_i64 : i64 - %true = arith.constant true - %false = arith.constant false - %46:9 = scf.for %arg9 = %c0_i64 to %44 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1) : i64 { - %47 = arith.addi %arg10, %c1_i64 : i64 - %48 = arith.remsi %47, %42 : i64 - %49 = arith.cmpi eq, %48, %c0_i64 : i64 - %50 = arith.select %49, %c0_i32, %arg12 : i32 - %51 = arith.select %49, %false, %arg18 : i1 - %52:4 = scf.if %49 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) { - %81 = arith.divsi %arg11, %14 : i32 - %82 = arith.muli %81, %c8_i32 : i32 - %83 = arith.subi %8, %82 : i32 - %84 = arith.minsi %83, %c8_i32 : i32 - %85 = arith.remsi %arg11, %84 : i32 - %86 = arith.addi %82, %85 : i32 - %87 = arith.remsi %arg11, %14 : i32 - %88 = arith.divsi %87, %84 : i32 - %89 = arith.muli %86, %c128_i32 : i32 - %90 = arith.muli %88, %c256_i32 : i32 - %91 = tt.splat %89 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %92 = arith.addi %91, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %93 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %94 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %95 = arith.addi %93, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %96 = arith.addi %94, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %97 = arith.cmpi slt, %92, %22 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %98 = arith.select %97, %92, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %99 = arith.cmpi slt, %95, %23 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %100 = arith.select %99, %95, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %101 = tt.expand_dims %98 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %102 = arith.muli %101, %24 : tensor<128x1xi32, #blocked1> - %103 = tt.broadcast %102 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %104 = tt.expand_dims %100 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> 
tensor<1x256xi32, #blocked> - %105 = arith.muli %104, %26 : tensor<1x256xi32, #blocked> - %106 = tt.broadcast %105 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - scf.yield %103, %106, %96, %92 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - } else { - scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %25 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %26 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %27 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %28 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %29 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %30 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked3> + %32 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked3> + %34 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked3> + %36 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked3> + %38 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %39 = arith.cmpi eq, %12, %c0_i32 : i32 + scf.if %39 { + scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { + %40 = arith.divsi %arg9, %14 : i32 + %41 = arith.muli %40, %c8_i32 : i32 + %42 = arith.subi %8, %41 : i32 + %43 = arith.minsi %42, %c8_i32 : i32 + %44 = arith.remsi %arg9, %43 : i32 + %45 = arith.addi %41, %44 : i32 + %46 = arith.remsi %arg9, %14 : i32 + %47 = arith.divsi %46, %43 : i32 + %48 = arith.muli %45, %c128_i32 : i32 + %49 = arith.muli %47, %c256_i32 : i32 + %50 = tt.splat %48 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %51 = arith.addi %50, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %52 = tt.splat %49 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %53 = arith.addi %52, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %54 = tt.expand_dims %51 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1xi32, #blocked3> + %55 = arith.muli %31, %54 : 
tensor<128x1xi32, #blocked3> + %56 = tt.addptr %33, %55 : tensor<128x1x!tt.ptr, #blocked3>, tensor<128x1xi32, #blocked3> + %57 = tt.expand_dims %53 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x256xi32, #blocked3> + %58 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked3> -> tensor<128x256x!tt.ptr, #blocked3> + %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked3> -> tensor<128x256xi32, #blocked3> + %60 = tt.addptr %58, %59 : tensor<128x256x!tt.ptr, #blocked3>, tensor<128x256xi32, #blocked3> + %61 = arith.cmpi slt, %54, %35 : tensor<128x1xi32, #blocked3> + %62 = arith.cmpi slt, %57, %37 : tensor<1x256xi32, #blocked3> + %63 = tt.broadcast %61 : tensor<128x1xi1, #blocked3> -> tensor<128x256xi1, #blocked3> + %64 = tt.broadcast %62 : tensor<1x256xi1, #blocked3> -> tensor<128x256xi1, #blocked3> + %65 = arith.andi %63, %64 : tensor<128x256xi1, #blocked3> + tt.store %60, %cst_1, %65 : tensor<128x256x!tt.ptr, #blocked3> } - %53 = arith.muli %50, %c64_i32 : i32 - %54 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %55 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %56 = arith.addi %54, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %57 = arith.addi %55, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %58 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %59 = tt.broadcast %58 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %60 = arith.addi %52#0, %59 : tensor<128x64xi32, #blocked1> - %61 = tt.addptr %25, %60 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %62 = tt.expand_dims %57 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %63 = tt.broadcast %62 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %64 = arith.addi %63, %52#1 : tensor<64x256xi32, #blocked> - %65 = tt.addptr %27, %64 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %66 = arith.subi %arg5, %53 : i32 - %67 = tt.splat %66 : i32 -> tensor<1x64xi32, #blocked1> - %68 = arith.cmpi slt, %28, %67 : tensor<1x64xi32, #blocked1> - %69 = tt.broadcast %68 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %70 = tt.load %61, %69, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> - %71 = ttg.local_alloc %70 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %72 = tt.splat %66 : i32 -> tensor<64x1xi32, #blocked> - %73 = arith.cmpi slt, %29, %72 : tensor<64x1xi32, #blocked> - %74 = tt.broadcast %73 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %75 = tt.load %65, %74, %cst_3 : tensor<64x256x!tt.ptr, #blocked> - %76 = ttg.local_alloc %75 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared, #smem> - %77 = ttng.warp_group_dot %71, %76, %arg13, %51 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> - %78 = arith.addi %50, %c1_i32 : i32 - %79 = arith.cmpi eq, %48, %45 : i64 - %80 = scf.if %79 -> (i32) { - %81 = tt.expand_dims %52#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %82 = arith.muli %31, %81 : tensor<128x1xi32, #blocked1> - %83 = tt.addptr %33, %82 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %84 = tt.expand_dims %52#2 {axis = 0 : i32} : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %85 = tt.broadcast %83 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> - %86 = tt.broadcast %84 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> - %87 = tt.addptr %85, %86 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> - %88 = arith.cmpi slt, %81, %35 : tensor<128x1xi32, #blocked1> - %89 = arith.cmpi slt, %84, %37 : tensor<1x256xi32, #blocked1> - %90 = tt.broadcast %88 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> - %91 = tt.broadcast %89 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> - %92 = arith.andi %90, %91 : tensor<128x256xi1, #blocked1> - %93 = arith.truncf %77 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %94 = ttg.convert_layout %93 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.store %87, %94, %92 : tensor<128x256x!tt.ptr, #blocked1> - %95 = arith.addi %arg11, %c132_i32 : i32 - scf.yield %95 : i32 - } else { - scf.yield %arg11 : i32 + } else { + %40 = arith.subi %13, %6 : i32 + %41 = arith.ceildivsi %40, %c132_i32 : i32 + %42 = arith.extsi %12 : i32 to i64 + %43 = arith.maxsi %42, %c1_i64 : i64 + %44 = arith.extsi %41 : i32 to i64 + %45 = arith.muli %44, %43 : i64 + %46 = arith.subi %43, %c1_i64 : i64 + %true = arith.constant true + %false = arith.constant false + %47:9 = scf.for %arg9 = %c0_i64 to %45 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>, i1) : i64 { + %48 = arith.addi %arg10, %c1_i64 : i64 + %49 = arith.remsi %48, %43 : i64 + %50 = arith.cmpi eq, %49, %c0_i64 : i64 + %51 = arith.select %50, %c0_i32, %arg12 : i32 + %52 = arith.select %50, %false, %arg18 : i1 + %53:4 = scf.if %50 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>) { + %82 = arith.divsi %arg11, %14 : i32 + %83 = arith.muli %82, %c8_i32 : i32 + %84 = arith.subi %8, %83 : i32 + %85 = arith.minsi %84, %c8_i32 : i32 + %86 = arith.remsi %arg11, %85 : i32 + %87 = arith.addi %83, %86 : i32 + %88 = arith.remsi %arg11, %14 : i32 + %89 = arith.divsi %88, %85 : i32 + %90 = arith.muli %87, %c128_i32 : i32 + %91 = arith.muli %89, %c256_i32 : i32 + %92 = tt.splat %90 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %93 = tt.splat %90 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %94 = arith.addi %92, %19 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %95 = arith.addi %93, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %96 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %97 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %98 = arith.addi %96, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %99 = arith.addi %97, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %100 = arith.cmpi slt, %94, %23 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %101 = arith.select %100, %94, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> 
: tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %102 = arith.cmpi slt, %98, %24 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %103 = arith.select %102, %98, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %104 = tt.expand_dims %101 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %105 = arith.muli %104, %25 : tensor<128x1xi32, #blocked1> + %106 = tt.broadcast %105 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %107 = tt.expand_dims %103 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %108 = arith.muli %107, %27 : tensor<1x256xi32, #blocked> + %109 = tt.broadcast %108 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + scf.yield %106, %109, %99, %95 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + } else { + scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + } + %54 = arith.muli %51, %c64_i32 : i32 + %55 = tt.splat %54 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %56 = tt.splat %54 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %57 = arith.addi %55, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %58 = arith.addi %56, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %59 = tt.expand_dims %57 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %60 = tt.broadcast %59 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %61 = arith.addi %53#0, %60 : tensor<128x64xi32, #blocked1> + %62 = tt.addptr %26, %61 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %63 = tt.expand_dims %58 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %64 = tt.broadcast %63 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %65 = arith.addi %64, %53#1 : tensor<64x256xi32, #blocked> + %66 = tt.addptr %28, %65 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %67 = arith.subi %arg5, %54 : i32 + %68 = tt.splat %67 : i32 -> tensor<1x64xi32, #blocked1> + %69 = arith.cmpi slt, %29, %68 : tensor<1x64xi32, #blocked1> + %70 = tt.broadcast %69 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %71 = tt.load %62, %70, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> + %72 = ttg.local_alloc %71 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %73 = tt.splat %67 : i32 -> tensor<64x1xi32, #blocked> + %74 = arith.cmpi slt, %30, %73 : tensor<64x1xi32, #blocked> + %75 = tt.broadcast %74 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %76 = tt.load %66, %75, %cst_3 : tensor<64x256x!tt.ptr, #blocked> + %77 = ttg.local_alloc %76 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %78 = ttng.warp_group_dot %72, %77, %arg13, %52 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, 
#shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + %79 = arith.addi %51, %c1_i32 : i32 + %80 = arith.cmpi eq, %49, %46 : i64 + %81 = scf.if %80 -> (i32) { + %82 = tt.expand_dims %53#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %83 = arith.muli %32, %82 : tensor<128x1xi32, #blocked2> + %84 = tt.addptr %34, %83 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %85 = tt.expand_dims %53#2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %86 = tt.broadcast %84 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %87 = tt.broadcast %85 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %88 = tt.addptr %86, %87 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %89 = arith.cmpi slt, %82, %36 : tensor<128x1xi32, #blocked2> + %90 = arith.cmpi slt, %85, %38 : tensor<1x256xi32, #blocked2> + %91 = tt.broadcast %89 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %92 = tt.broadcast %90 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %93 = arith.andi %91, %92 : tensor<128x256xi1, #blocked2> + %94 = arith.truncf %78 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %95 = ttg.convert_layout %94 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %88, %95, %93 : tensor<128x256x!tt.ptr, #blocked2> + %96 = arith.addi %arg11, %c132_i32 : i32 + scf.yield %96 : i32 + } else { + scf.yield %arg11 : i32 + } + scf.yield %49, %81, %79, %78, %53#0, %53#1, %53#2, %53#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>, i1 } - scf.yield %48, %80, %78, %77, %52#0, %52#1, %52#2, %52#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1 } tt.return } From 90930f0b1d25ce915a2cd3a992aeb11d0460f091 Mon Sep 17 00:00:00 2001 From: Mogball Date: Fri, 24 Jan 2025 13:32:19 -0500 Subject: [PATCH 04/32] persistent with loop fusion --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 172 ++++---- .../Transforms/Pipeliner/AssignLatencies.cpp | 4 - .../Transforms/Pipeliner/PipelineExpander.cpp | 6 +- orig.mlir | 381 ------------------ python/src/passes.cc | 1 + python/tutorials/09-persistent-matmul.py | 11 +- test.mlir | 177 -------- test1.mlir | 175 -------- test2.mlir | 203 ---------- test3.mlir | 197 --------- third_party/amd/backend/compiler.py | 2 + third_party/nvidia/backend/compiler.py | 2 + 12 files changed, 110 insertions(+), 1221 deletions(-) delete mode 100644 orig.mlir delete mode 100644 test.mlir delete mode 100644 test1.mlir delete mode 100644 test2.mlir delete mode 100644 test3.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index 3fb34ad5f40b..c246ce9b9a92 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -1,6 +1,8 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include 
"mlir/Transforms/RegionUtils.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "llvm/Support/Debug.h" @@ -296,27 +298,27 @@ static unsigned getIntTypeWidth(Type type) { } // Generate IR to compute the number of iterations of a loop. -static Value computeNumIters(OpBuilder &b, scf::ForOp loop) { +static Value computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) { // len(range(lb, ub, step)) = ceildiv(ub - lb, step) // This works even if step is negative. Location loc = loop.getLoc(); Value diff = - b.create(loc, loop.getUpperBound(), loop.getLowerBound()); + b.create(loop.getUpperBound(), loop.getLowerBound()); // Let someone else prove it can be unsigned. - return b.create(loc, diff, loop.getStep()); + return b.create(diff, loop.getStep()); } // Cast an integer or index value to an integer or index `type`, if necessary. -static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, +static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, Type type) { if (value.getType() == type) return value; if (isa(value.getType()) || isa(type)) - return b.create(loc, type, value); + return b.create(type, value); if (cast(value.getType()).getWidth() > cast(type).getWidth()) - return b.create(loc, type, value); - return b.create(loc, type, value); + return b.create(type, value); + return b.create(type, value); } // Given a one level loop nest in the form @@ -345,11 +347,12 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // total_iters = len_i * inner_len // // T = -1 -// i = lbi +// i = lbi - stepi // for _ in range(total_iters): // T = (T + 1) % inner_len // // if T == 0: +// i += stepi // prologue0(i) // j0 = lbj0 // if T >= 0 and T < len_j0: @@ -385,7 +388,6 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // // if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - (N + 1): // epilogue(i) -// i += stepi // // This routine can be applied recursively on a loop nest tree, leaf-to-root, to // flatten the loop nest into a single loop. However, this routine only fuses @@ -444,7 +446,9 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // } // // Note: the induction variables will be initialized to their lower bound to -// avoid underflow in lbjk - stepjk. +// avoid underflow in lbjk - stepjk, with the exception of the outer loop +// induction variable, which needs to be incremented inside the prologue to +// avoid a dependency on the epilogue. This helps the scheduler behave. // // Any inputs and outputs of the loop bodies would also need to be handled // similarly: initialized as undef if appropriate and carried through the fused @@ -497,7 +501,8 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { unsigned intTyWidth = getIntTypeWidth(outer.getInductionVar().getType()); // Generate the computations of the fused loop bounds. 
-  OpBuilder b(outer);
+  Location loc = outer.getLoc();
+  ImplicitLocOpBuilder b(loc, outer);
   Value lenOuter = computeNumIters(b, outer);
   SmallVector<Value> lenInners;
   for (scf::ForOp loop : innerLoops) {
@@ -509,9 +514,8 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
     intTyWidth = std::min(64u, intTyWidth * 2);
   auto intTy = b.getIntegerType(intTyWidth);
 
-  Location loc = outer.getLoc();
   auto intTyCst = [&](int64_t v) {
-    return b.create<arith::ConstantOp>(loc, IntegerAttr::get(intTy, v));
+    return b.create<arith::ConstantOp>(IntegerAttr::get(intTy, v));
   };
 
   // inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N
@@ -521,16 +525,16 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   SmallVector<Value> partialInnerSums;
   partialInnerSums.push_back(innerLen);
   for (Value lenInner : lenInners) {
-    lenInner = castIntIfNecessary(b, loc, lenInner, intTy);
-    lenInner = b.create<arith::MaxSIOp>(loc, intTyCst(1), lenInner);
-    innerLen = b.create<arith::AddIOp>(loc, innerLen, lenInner);
+    lenInner = castIntIfNecessary(b, lenInner, intTy);
+    lenInner = b.create<arith::MaxSIOp>(intTyCst(1), lenInner);
+    innerLen = b.create<arith::AddIOp>(innerLen, lenInner);
     partialInnerSums.push_back(innerLen);
   }
-  innerLen = b.create<arith::SubIOp>(loc, innerLen, intTyCst(N));
+  innerLen = b.create<arith::SubIOp>(innerLen, intTyCst(N));
 
   // total_iters = len_i * inner_len
-  Value totalIters = b.create<arith::MulIOp>(
-      loc, castIntIfNecessary(b, loc, lenOuter, intTy), innerLen);
+  Value totalIters =
+      b.create<arith::MulIOp>(castIntIfNecessary(b, lenOuter, intTy), innerLen);
 
   // The outputs of the prologue, each epilogue, and all inner loop bodies need
   // to carried through the fused loop.
@@ -561,8 +565,9 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
 
   // T = -1
   fusedInits.push_back(intTyCst(-1));
-  // i = lbi
-  fusedInits.push_back(outer.getLowerBound());
+  // i = lbi - stepi
+  fusedInits.push_back(
+      b.create<arith::SubIOp>(outer.getLowerBound(), outer.getStep()));
 
   unsigned outerArgsStartIdx = fusedInits.size();
   llvm::append_range(fusedInits, outer.getInits());
@@ -571,22 +576,22 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   unsigned ivarStartIdx = fusedInits.size();
   for (scf::ForOp loop : innerLoops) {
     fusedInits.push_back(
-        b.create<ub::PoisonOp>(loc, loop.getInductionVar().getType()));
+        b.create<ub::PoisonOp>(loop.getInductionVar().getType()));
   }
   unsigned innerOutsStartIdx = fusedInits.size();
   for (scf::ForOp loop : innerLoops) {
     for (Type resultType : loop.getResultTypes())
-      fusedInits.push_back(b.create<ub::PoisonOp>(loc, resultType));
+      fusedInits.push_back(b.create<ub::PoisonOp>(resultType));
   }
   unsigned logueOutsStartIdx = fusedInits.size();
   for (Logue &logue : logues) {
     for (Type outputType : logue.getOutputTypes())
-      fusedInits.push_back(b.create<ub::PoisonOp>(loc, outputType));
+      fusedInits.push_back(b.create<ub::PoisonOp>(outputType));
   }
 
   // for _ in range(total_iters):
-  auto fused = b.create<scf::ForOp>(loc, intTyCst(0), totalIters, intTyCst(1),
-                                    fusedInits);
+  auto fused =
+      b.create<scf::ForOp>(intTyCst(0), totalIters, intTyCst(1), fusedInits);
 
   // Replace the outer loop args with the args in the fused loop args.
   for (auto [arg, fusedArg] : llvm::zip(outer.getRegionIterArgs(),
@@ -597,12 +602,12 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
 
   // T = (T + 1) % inner_len
   Value T = fused.getRegionIterArg(0);
-  T = b.create<arith::AddIOp>(loc, T, intTyCst(1));
-  T = b.create<arith::RemSIOp>(loc, T, innerLen);
+  T = b.create<arith::AddIOp>(T, intTyCst(1));
+  T = b.create<arith::RemSIOp>(T, innerLen);
 
-  // Replace uses of `i` within the fused loop.
- Value i = fused.getRegionIterArg(1); - outer.getInductionVar().replaceAllUsesWith(i); + // `i` is computed inside the first prologue. + Value curI = fused.getRegionIterArg(1); + Value i; assert(partialInnerSums.size() == N + 2); ArrayRef ivars = fused.getRegionIterArgs().slice(ivarStartIdx); @@ -610,15 +615,16 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx; auto logueOutsIt = ValueRange(fused.getRegionIterArgs()).begin() + logueOutsStartIdx; - SmallVector logueIfs, bodyIfs; + SmallVector prologueIfs, bodyIfs; for (unsigned k = 0; k <= N; ++k) { // if T == max(1, len_j0) + ... max(1, len_jk-1) - k + // [[if k == 0]] i += stepi // prologuek(i) // jk = lbjk Value innerStartT = - b.create(loc, partialInnerSums[k], intTyCst(k)); + b.create(partialInnerSums[k], intTyCst(k)); Value prologueCond = - b.create(loc, arith::CmpIPredicate::eq, T, innerStartT); + b.create(arith::CmpIPredicate::eq, T, innerStartT); // The `scf.if` outputs will be `jk` and the outputs of prologuek. We also // have to initialize the inner loop iter args. @@ -628,20 +634,32 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector prologueOutTypes{inner.getInductionVar().getType()}; llvm::append_range(prologueOutTypes, prologue.getOutputTypes()); llvm::append_range(prologueOutTypes, inner.getInits().getTypes()); - auto prologueIf = b.create(loc, prologueOutTypes, prologueCond); - logueIfs.push_back(prologueIf); + if (k == 0) + prologueOutTypes.push_back(curI.getType()); + auto prologueIf = b.create(prologueOutTypes, prologueCond); + prologueIfs.push_back(prologueIf); // Splice prologuek into the `then` region. Block *thenBlock = b.createBlock(&prologueIf.getThenRegion()); prologue.moveBefore(thenBlock, thenBlock->end()); + if (k == 0) { + // Increment `i` and replace its uses inside the prologue. + b.setInsertionPointToStart(thenBlock); + i = b.create(curI, outer.getStep()); + mlir::replaceAllUsesInRegionWith(outer.getInductionVar(), i, + prologueIf.getThenRegion()); + } + // Yield the initialized jk, the prologue outputs, and the initial values of // the inner loop. b.setInsertionPointToEnd(thenBlock); SmallVector thenOuts{inner.getLowerBound()}; llvm::append_range(thenOuts, prologue.getOutputs()); llvm::append_range(thenOuts, inner.getInits()); - b.create(loc, thenOuts); + if (k == 0) + thenOuts.push_back(i); + b.create(thenOuts); // In the `else` region, just yield the last values of jk, the outputs, and // the iter args. @@ -651,8 +669,10 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector elseOuts{lastJk}; elseOuts.append(logueOutsIt, logueOutsIt + numOuts); elseOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + if (k == 0) + elseOuts.push_back(curI); logueOutsIt += numOuts; - b.create(loc, elseOuts); + b.create(elseOuts); // The results of the `scf.if` become the values of jk and the prologue // outputs for the rest of the fused loop. @@ -665,6 +685,11 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { for (auto [init, iterArg] : llvm::zip(prologueInits, inner.getRegionIterArgs())) iterArg.replaceAllUsesWith(init); + // Replace uses of `i` elsewhere with the prologue result. + if (k == 0) { + i = prologueIf.getResults().back(); + outer.getInductionVar().replaceAllUsesWith(i); + } // if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k // and T < max(1, len_j0) + max(1, len_j1) + ... 
+ max(1, len_jk-1) - k + @@ -673,17 +698,16 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { // jk += stepjk b.setInsertionPointAfter(prologueIf); Value innerEndT = b.create( - loc, innerStartT, castIntIfNecessary(b, loc, lenInners[k], intTy)); + innerStartT, castIntIfNecessary(b, lenInners[k], intTy)); Value ge = - b.create(loc, arith::CmpIPredicate::sge, T, innerStartT); - Value lt = - b.create(loc, arith::CmpIPredicate::slt, T, innerEndT); - Value bodyCond = b.create(loc, ge, lt); + b.create(arith::CmpIPredicate::sge, T, innerStartT); + Value lt = b.create(arith::CmpIPredicate::slt, T, innerEndT); + Value bodyCond = b.create(ge, lt); // The outputs will be the outputs of the inner loop body and the next jk. SmallVector bodyOutTypes{jk.getType()}; llvm::append_range(bodyOutTypes, inner->getResultTypes()); - auto bodyIf = b.create(loc, bodyOutTypes, bodyCond); + auto bodyIf = b.create(bodyOutTypes, bodyCond); bodyIfs.push_back(bodyIf); // Splice bodyk into the `then` region. @@ -692,7 +716,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { auto yield = cast(bodyIf.getThenRegion().front().getTerminator()); b.setInsertionPoint(yield); - Value nextJk = b.create(loc, jk, inner.getStep()); + Value nextJk = b.create(jk, inner.getStep()); yield->insertOperands(0, nextJk); // The `else` region just forwards the values. @@ -700,7 +724,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector bodyForwardedOuts{jk}; bodyForwardedOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); bodyOutsIt += inner->getNumResults(); - b.create(loc, bodyForwardedOuts); + b.create(bodyForwardedOuts); // Now we can replace the results of the inner loop with the outputs of the // body if. @@ -712,7 +736,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { if (inner->hasAttr(kMustExecuteAttrName)) { b.setInsertionPoint(bodyIf); bodyIf.getConditionMutable().assign( - b.create(loc, b.getBoolAttr(true))); + b.create(b.getBoolAttr(true))); } // Move the insertion point for the next iteration. @@ -721,35 +745,28 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { // if T == len_j0 + len_j1 + ... 
+ len_jN - N - 1: // epilogue(i) - // i += stepi Logue &epilogue = logues.back(); - auto epilogueCond = b.create( - loc, arith::CmpIPredicate::eq, T, - b.create(loc, innerLen, intTyCst(1))); - SmallVector epilogueOutTypes{i.getType()}; - llvm::append_range(epilogueOutTypes, epilogue.getOutputTypes()); - auto epilogueIf = b.create(loc, epilogueOutTypes, epilogueCond); - logueIfs.push_back(epilogueIf); + auto epilogueCond = + b.create(arith::CmpIPredicate::eq, T, + b.create(innerLen, intTyCst(1))); + auto epilogueIf = + b.create(epilogue.getOutputTypes(), epilogueCond); Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion()); epilogue.moveBefore(thenBlock, thenBlock->end()); b.setInsertionPointToEnd(thenBlock); - Value nextI = b.create(loc, i, outer.getStep()); - SmallVector thenOuts{nextI}; - llvm::append_range(thenOuts, epilogue.getOutputs()); - b.create(loc, thenOuts); + b.create(epilogue.getOutputs()); b.createBlock(&epilogueIf.getElseRegion()); - SmallVector elseOuts{i}; - elseOuts.append(logueOutsIt, logueOutsIt + epilogue.getNumOutputs()); - b.create(loc, elseOuts); - epilogue.replaceAllUsesWith( - epilogueIf.getResults().slice(1, epilogue.getNumOutputs()), - epilogueIf.getThenRegion()); + SmallVector elseOuts(logueOutsIt, + logueOutsIt + epilogue.getNumOutputs()); + b.create(elseOuts); + epilogue.replaceAllUsesWith(epilogueIf.getResults(), + epilogueIf.getThenRegion()); // Finally, create the yield of the fused loop. - SmallVector outerOuts{T, /*i=*/epilogueIf.getResult(0)}; + SmallVector outerOuts{T, i}; llvm::append_range(outerOuts, outer.getYieldedValues()); for (scf::IfOp bodyIf : bodyIfs) outerOuts.push_back(/*jk=*/bodyIf.getResult(0)); @@ -758,13 +775,14 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { bodyIf.getResults().slice(1, loop.getNumResults())); loop.erase(); } - for (auto [logueIf, logue] : llvm::zip(logueIfs, logues)) { + for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) { llvm::append_range(outerOuts, logueIf.getResults().slice(1, logue.getNumOutputs())); } + llvm::append_range(outerOuts, epilogue.getOutputs()); b.setInsertionPointToEnd(fused.getBody()); - b.create(loc, outerOuts); + b.create(outerOuts); outer.replaceAllUsesWith( fused.getResults().slice(outerArgsStartIdx, outer.getNumResults())); outer.erase(); @@ -796,8 +814,9 @@ static bool shouldFuse(const LoopNest &nest) { return false; scf::ForOp innerLoop = nest.root->children.front()->loop; - return llvm::any_of(innerLoop.getOps(), - [](Operation &op) { return isa(op); }); + return llvm::any_of(innerLoop.getOps(), [](Operation &op) { + return op.hasTrait(); + }); } // Speculate the length of the inner loop such that the loop is known to execute @@ -827,24 +846,23 @@ static LogicalResult speculateInnerLoopLength(const LoopNest &nest, op->moveBefore(outerLoop); // Mark the inner loop. - OpBuilder b(outerLoop); + ImplicitLocOpBuilder b(loc, outerLoop); innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr()); // Speculate on whether the length of the inner loop is zero. Value lenInner = computeNumIters(b, innerLoop); auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0); Value innerLoopEmpty = - b.create(loc, arith::CmpIPredicate::eq, lenInner, - b.create(loc, zeroAttr)); - auto ifOp = - b.create(loc, outerLoop.getResultTypes(), innerLoopEmpty); + b.create(arith::CmpIPredicate::eq, lenInner, + b.create(zeroAttr)); + auto ifOp = b.create(outerLoop.getResultTypes(), innerLoopEmpty); // In the `then` branch, the inner loop does not execute. 
Clone the loop nest // into it and remove the inner loop. mlir::IRMapping map; b.createBlock(&ifOp.getThenRegion()); auto newLoop = cast(b.clone(*outerLoop, map)); - b.create(loc, newLoop.getResults()); + b.create(newLoop.getResults()); auto newInnerLoop = cast(map.lookup(innerLoop)); newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits()); newInnerLoop.erase(); @@ -854,7 +872,7 @@ static LogicalResult speculateInnerLoopLength(const LoopNest &nest, Block *block = b.createBlock(&ifOp.getElseRegion()); outerLoop->remove(); b.insert(outerLoop); - b.create(loc, outerLoop.getResults()); + b.create(outerLoop.getResults()); return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp index 5226c8e831a4..20d5d418fdfa 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -125,11 +125,9 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, [&](Operation *op, Operation *finalUser, int distance) { if (!seen.insert(op).second || excluded.count(op)) return; - op->dump(); if (isa(op)) { if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis)) return; - op->dump(); if (loadOpToIndLevel.count(op)) { int level = loadOpToIndLevel[op]; if (level != distance) { @@ -170,7 +168,6 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, continue; seenDot = true; seen.clear(); - op.dump(); dfs(&op, &op, 0); } @@ -232,7 +229,6 @@ DenseMap assignLatencies(ModuleOp moduleOp, DenseMap opLatency; for (auto forOp : loops) { - forOp.dump(); if (hasLatenciesAssigned(forOp)) { assignUserProvidedLatencies(forOp, opLatency); continue; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 20fcba4d7321..d7ff515269ae 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -248,7 +248,11 @@ bool LoopPipelinerInternal::verifySchedule() { continue; int64_t producerCycle = it->second; if (consumerCycle < producerCycle - numCylesPerIter * distance) { - consumer->emitError("operation scheduled before its operands"); + InFlightDiagnostic diag = + consumer->emitError("operation scheduled before its operands"); + diag.attachNote(producer->getLoc()) + .append("operand defined here: ") + .appendOp(*producer, OpPrintingFlags().printGenericOpForm()); return false; } } diff --git a/orig.mlir b/orig.mlir deleted file mode 100644 index c6f2580f8198..000000000000 --- a/orig.mlir +++ /dev/null @@ -1,381 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 
8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { - %c2_i32 = arith.constant 2 : i32 loc(#loc1) - %c3_i32 = arith.constant 3 : i32 loc(#loc1) - %false = arith.constant false loc(#loc1) - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) - %c256_i32 = arith.constant 256 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %c-1_i32 = arith.constant -1 : i32 loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c132_i32 = arith.constant 132 : i32 loc(#loc1) - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %c127_i32 = arith.constant 127 : i32 loc(#loc1) - %c255_i32 = arith.constant 255 : i32 loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc80) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc81) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc82) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc83) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc84) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc85) - %7 = arith.muli %2, %4 : i32 loc(#loc8) - %8 = arith.divsi %7, %c132_i32 : i32 loc(#loc9) - %9 = arith.remsi %7, %c132_i32 : i32 loc(#loc10) - %10 = arith.cmpi slt, %0, %9 : i32 loc(#loc11) - %11 = scf.if %10 -> (i32) { - %122 = arith.addi %8, %c1_i32 : i32 loc(#loc13) - scf.yield %122 : i32 loc(#loc13) - } else { - scf.yield %8 : i32 loc(#loc1) - } loc(#loc12) - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc14) - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc14) - %14 = arith.muli %4, %c8_i32 : i32 loc(#loc15) - %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc16) - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, 
#ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc16) - %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc17) - %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc17) - %19 = arith.muli %6, %11 : i32 loc(#loc18) - %20 = arith.subi %6, %c1_i32 : i32 loc(#loc19) - %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc20) - %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc21) - %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc22) - %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc23) - %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc24) - %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc25) - %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc26) - %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc27) - %29 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc28) - %30 = arith.divsi %0, %14 : i32 loc(#loc29) - %31 = arith.muli %30, %c8_i32 : i32 loc(#loc30) - %32 = arith.subi %2, %31 : i32 loc(#loc31) - %33 = arith.minsi %32, %c8_i32 : i32 loc(#loc32) - %34 = arith.remsi %0, %33 : i32 loc(#loc33) - %35 = arith.addi %31, %34 : i32 loc(#loc34) - %36 = arith.remsi %0, %14 : i32 loc(#loc35) - %37 = arith.divsi %36, %33 : i32 loc(#loc36) - %38 = arith.muli %35, %c128_i32 : i32 loc(#loc37) - %39 = arith.muli %37, %c256_i32 : i32 loc(#loc38) - %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> 
loc(#loc46) - %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> loc(#loc48) - %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) - %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) - %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc28) - %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %69 = ttg.async_commit_group %68 loc(#loc26) - %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> loc(#loc50) - %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc28) - %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %77 = ttg.async_commit_group %76 loc(#loc27) - %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc28) - %79 = arith.cmpi ne, %20, %c0_i32 : i32 loc(#loc86) - %80 = arith.extui %79 : i1 to i32 loc(#loc51) - %81 = arith.cmpi eq, %80, %c0_i32 : i32 loc(#loc53) - %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %122 = arith.addi %0, %c132_i32 : i32 loc(#loc55) - %123 = arith.divsi %122, %14 : i32 loc(#loc29) - %124 = arith.muli %123, %c8_i32 : i32 loc(#loc30) - %125 = arith.subi %2, %124 : i32 loc(#loc31) - %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc32) - %127 = arith.remsi %122, %126 : i32 loc(#loc33) - %128 = arith.addi %124, %127 : i32 loc(#loc34) - %129 = arith.remsi %122, %14 : i32 loc(#loc35) - %130 = arith.divsi %129, %126 : i32 loc(#loc36) - %131 = arith.muli %128, %c128_i32 : i32 loc(#loc37) - %132 = arith.muli %130, %c256_i32 : i32 loc(#loc38) - %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %136 = arith.addi %135, %17 : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - } else { - scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - } loc(#loc54) - %83 = arith.muli %80, %c64_i32 : i32 loc(#loc56) - %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) - %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc46) - %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) - %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc48) - %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) - %102 = arith.subi %arg5, %83 : i32 loc(#loc60) - %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) - %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> 
-> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc28) - %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %110 = ttg.async_commit_group %109 loc(#loc26) - %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> loc(#loc50) - %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc28) - %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %118 = ttg.async_commit_group %117 loc(#loc27) - %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { - %122 = arith.subi %19, %c2_i32 : i32 loc(#loc28) - %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc28) - %124 = arith.cmpi eq, %arg10, %20 : i32 loc(#loc52) - %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc61) - %126 = arith.select %124, %c0_i32, %125 : i32 loc(#loc51) - %127 = arith.cmpi eq, %126, %c0_i32 : i32 loc(#loc53) - %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %178 = arith.addi %arg11, %c132_i32 : i32 loc(#loc55) - %179 = arith.divsi %178, %14 : i32 loc(#loc29) - %180 = arith.muli %179, %c8_i32 : i32 loc(#loc30) - %181 = arith.subi %2, %180 : i32 loc(#loc31) - %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc32) - %183 = arith.remsi %178, %182 : i32 loc(#loc33) - %184 = arith.addi %180, %183 : i32 loc(#loc34) - %185 = arith.remsi %178, %14 : i32 loc(#loc35) - %186 = arith.divsi %185, %182 : i32 loc(#loc36) - %187 = arith.muli %184, %c128_i32 : i32 loc(#loc37) - %188 = arith.muli %186, %c256_i32 : i32 loc(#loc38) - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : 
tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - } else { - scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - } loc(#loc54) - %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc28) - %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc28) - %131 = arith.select %130, %129, %c0_i32 : i32 loc(#loc28) - %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc26) - %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc62) - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) - %137 = arith.addi %arg18, %c1_i32 : i32 loc(#loc28) - %138 = arith.cmpi slt, %137, %c3_i32 : i32 loc(#loc28) - %139 = arith.select %138, %137, %c0_i32 : i32 loc(#loc28) - %140 = arith.muli %126, %c64_i32 : i32 loc(#loc56) - %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) - %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> loc(#loc46) - %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, 
#blocked> loc(#loc59) - %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> loc(#loc48) - %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) - %159 = arith.subi %arg5, %140 : i32 loc(#loc60) - %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> loc(#loc49) - %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> loc(#loc28) - %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %167 = ttg.async_commit_group %166 loc(#loc26) - %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> loc(#loc50) - %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> loc(#loc28) - %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %175 = ttg.async_commit_group %174 loc(#loc27) - %176 = arith.cmpi eq, %arg22, %20 : i32 loc(#loc63) - %177 = arith.cmpi ne, %arg22, %20 : i32 loc(#loc87) - scf.if %176 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) - %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc65) - %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) - %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) - %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc67) - %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) - %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) - %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc69) - %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc70) - %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc70) - %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc71) - %189 = 
tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc71) - %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc72) - %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc73) - %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc73) - %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc73) - %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc74) - %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc74) - %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc75) - %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc75) - %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) - %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) - %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc76) - %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc77) - %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc78) - tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc78) - } loc(#loc64) - scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc28) - } loc(#loc28) - %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc28) - %121 = ttg.async_wait {num = 0 : i32} loc(#loc28) - ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc28) - ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc28) - tt.return loc(#loc79) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) -#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) -#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":172:32) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:31) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:19) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:7) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:24) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:35) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":181:38) -#loc16 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:27) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:32) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:38) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:45) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:26) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:75) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:49) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:49) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:20) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:22) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:34) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:43) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:56) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:45) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:35) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:31) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:52) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":200:30) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:32) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:41) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:53) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:34) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:57) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:56) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:60) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:44) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:28) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:17) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:11) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:23) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:22) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:37) -#loc58 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:64) -#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:33) -#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:64) -#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:49) -#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:35) -#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:17) -#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:11) -#loc65 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:30) -#loc66 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:45) -#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) -#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) -#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:49) -#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:41) -#loc71 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:29) -#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:80) -#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:60) -#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) -#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:66) -#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:47) -#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":224:35) -#loc78 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:29) -#loc79 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:4) -#loc80 = loc(callsite(#loc3 at #loc4)) -#loc81 = loc(callsite(#loc5 at #loc4)) -#loc82 = loc(callsite(#loc3 at #loc6)) -#loc83 = loc(callsite(#loc5 at #loc6)) -#loc84 = loc(callsite(#loc3 at #loc7)) -#loc85 = loc(callsite(#loc5 at #loc7)) -#loc86 = loc(fused[#loc51, #loc52]) -#loc87 = loc(fused[#loc64, #loc63]) - diff --git a/python/src/passes.cc b/python/src/passes.cc index b0efc3cb884b..619ece2e3455 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -71,6 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCombineTensorSelectAndIf); ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops); ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling", createTritonGPULoopScheduling, int); ADD_PASS_WRAPPER_0("add_coalesce_async_copy", diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 0d776ba0f3fd..94067cd6b0f2 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -259,7 +259,6 @@ def matmul_persistent(a, b): num_warps=configs[dtype]["num_warps"], # grid=grid ) - print(kernel.asm["ttir"]) return c @@ -610,8 +609,8 @@ def show_profile(precision, profile_name): validate(32, 32, 32, dtype) validate(8192, 8192, 512, dtype) - #proton.start("matmul", hook="triton") - #for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - # bench(K, dtype) - #proton.finalize() - #show_profile(args.prec, "matmul") + proton.start("matmul", hook="triton") + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K, dtype) + proton.finalize() + show_profile(args.prec, 
"matmul") diff --git a/test.mlir b/test.mlir deleted file mode 100644 index 345d793c2124..000000000000 --- a/test.mlir +++ /dev/null @@ -1,177 +0,0 @@ -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) -module { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32> loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %c255_i32 = arith.constant 255 : i32 loc(#loc1) - %c127_i32 = arith.constant 127 : i32 loc(#loc1) - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> loc(#loc1) - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c132_i32 = arith.constant 132 : i32 loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %cst_2 = arith.constant dense<0> : tensor<256xi32> loc(#loc1) - %cst_3 = arith.constant dense<0> : tensor<128xi32> loc(#loc1) - %c256_i32 = arith.constant 256 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc63) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc64) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc65) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc66) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc67) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc68) - %7 = arith.muli %2, %4 : i32 loc(#loc8) - %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc10) - %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc11) - %11 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc12) - %12 = tt.splat %arg3 : i32 -> tensor<128xi32> loc(#loc13) - %13 = tt.splat %arg4 : i32 -> tensor<256xi32> loc(#loc14) - %14 = tt.splat %arg6 : i32 -> tensor<128x1xi32> loc(#loc15) - %15 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> loc(#loc16) - %16 = tt.splat %arg7 : i32 -> tensor<1x256xi32> loc(#loc17) - %17 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> loc(#loc18) - %18 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc19) - %19 = tt.expand_dims %9 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc20) - %20 = tt.splat %arg8 : i32 -> tensor<128x1xi32> loc(#loc21) - %21 = tt.splat %arg2 : 
!tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc22) - %22 = tt.splat %arg3 : i32 -> tensor<128x1xi32> loc(#loc23) - %23 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc24) - scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { - %24 = arith.divsi %arg9, %8 : i32 loc(#loc26) - %25 = arith.muli %24, %c8_i32 : i32 loc(#loc27) - %26 = arith.subi %2, %25 : i32 loc(#loc28) - %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc29) - %28 = arith.remsi %arg9, %27 : i32 loc(#loc30) - %29 = arith.addi %25, %28 : i32 loc(#loc31) - %30 = arith.remsi %arg9, %8 : i32 loc(#loc32) - %31 = arith.divsi %30, %27 : i32 loc(#loc33) - %32 = arith.muli %29, %c128_i32 : i32 loc(#loc34) - %33 = arith.muli %31, %c256_i32 : i32 loc(#loc35) - %34 = tt.splat %32 : i32 -> tensor<128xi32> loc(#loc36) - %35 = arith.addi %34, %10 : tensor<128xi32> loc(#loc36) - %36 = tt.splat %33 : i32 -> tensor<256xi32> loc(#loc37) - %37 = arith.addi %36, %11 : tensor<256xi32> loc(#loc37) - %38 = arith.cmpi slt, %35, %12 : tensor<128xi32> loc(#loc13) - %39 = arith.select %38, %35, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> loc(#loc38) - %40 = arith.cmpi slt, %37, %13 : tensor<256xi32> loc(#loc14) - %41 = arith.select %40, %37, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> loc(#loc39) - %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc40) - %43 = arith.muli %42, %14 : tensor<128x1xi32> loc(#loc15) - %44 = tt.broadcast %43 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc41) - %45 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc42) - %46 = arith.muli %45, %16 : tensor<1x256xi32> loc(#loc17) - %47 = tt.broadcast %46 : tensor<1x256xi32> -> tensor<64x256xi32> loc(#loc43) - %48 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst) -> (tensor<128x256xf32>) : i32 { - %62 = arith.muli %arg10, %c64_i32 : i32 loc(#loc45) - %63 = tt.splat %62 : i32 -> tensor<64xi32> loc(#loc46) - %64 = arith.addi %63, %9 : tensor<64xi32> loc(#loc46) - %65 = tt.expand_dims %64 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc47) - %66 = tt.broadcast %65 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc41) - %67 = arith.addi %44, %66 : tensor<128x64xi32> loc(#loc41) - %68 = tt.addptr %15, %67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc16) - %69 = tt.expand_dims %64 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc48) - %70 = tt.broadcast %69 : tensor<64x1xi32> -> tensor<64x256xi32> loc(#loc43) - %71 = arith.addi %70, %47 : tensor<64x256xi32> loc(#loc43) - %72 = tt.addptr %17, %71 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> loc(#loc18) - %73 = arith.subi %arg5, %62 : i32 loc(#loc49) - %74 = tt.splat %73 : i32 -> tensor<1x64xi32> loc(#loc50) - %75 = arith.cmpi slt, %18, %74 : tensor<1x64xi32> loc(#loc50) - %76 = tt.broadcast %75 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc51) - %77 = tt.load %68, %76, %cst_1 : tensor<128x64x!tt.ptr> loc(#loc51) - %78 = tt.splat %73 : i32 -> tensor<64x1xi32> loc(#loc52) - %79 = arith.cmpi slt, %19, %78 : tensor<64x1xi32> loc(#loc52) - %80 = tt.broadcast %79 : tensor<64x1xi1> -> tensor<64x256xi1> loc(#loc53) - %81 = tt.load %72, %80, %cst_0 : tensor<64x256x!tt.ptr> loc(#loc53) - %82 = tt.dot %77, %81, %arg11, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> loc(#loc54) - scf.yield %82 : tensor<128x256xf32> 
loc(#loc55) - } loc(#loc44) - %49 = tt.expand_dims %35 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc56) - %50 = arith.muli %20, %49 : tensor<128x1xi32> loc(#loc21) - %51 = tt.addptr %21, %50 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc22) - %52 = tt.expand_dims %37 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc57) - %53 = tt.broadcast %51 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> loc(#loc58) - %54 = tt.broadcast %52 : tensor<1x256xi32> -> tensor<128x256xi32> loc(#loc58) - %55 = tt.addptr %53, %54 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> loc(#loc58) - %56 = arith.cmpi slt, %49, %22 : tensor<128x1xi32> loc(#loc23) - %57 = arith.cmpi slt, %52, %23 : tensor<1x256xi32> loc(#loc24) - %58 = tt.broadcast %56 : tensor<128x1xi1> -> tensor<128x256xi1> loc(#loc59) - %59 = tt.broadcast %57 : tensor<1x256xi1> -> tensor<128x256xi1> loc(#loc59) - %60 = arith.andi %58, %59 : tensor<128x256xi1> loc(#loc59) - %61 = arith.truncf %48 : tensor<128x256xf32> to tensor<128x256xf16> loc(#loc60) - tt.store %55, %61, %60 : tensor<128x256x!tt.ptr> loc(#loc61) - } loc(#loc25) - tt.return loc(#loc62) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) -#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) -#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:38) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:35) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:41) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:41) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:37) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:37) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:49) -#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:30) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:79) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:30) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:53) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:53) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:37) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:25) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:37) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:62) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:47) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":176:30) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":177:33) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:39) -#loc29 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:52) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:41) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:31) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:27) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:48) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:26) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":183:26) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:28) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:28) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:49) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:49) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:38) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:61) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:68) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:60) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:24) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:26) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:41) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:68) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:68) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:64) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:24) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:64) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:24) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:39) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:12) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:45) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:76) -#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:56) -#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:43) -#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:31) -#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:25) -#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:4) -#loc63 = loc(callsite(#loc3 at #loc4)) -#loc64 = loc(callsite(#loc5 at #loc4)) -#loc65 = loc(callsite(#loc3 at #loc6)) -#loc66 = loc(callsite(#loc5 at #loc6)) -#loc67 = loc(callsite(#loc3 at #loc7)) -#loc68 = loc(callsite(#loc5 at #loc7)) diff --git a/test1.mlir b/test1.mlir deleted file mode 100644 index 691cd743c1bb..000000000000 --- a/test1.mlir +++ /dev/null @@ -1,175 +0,0 @@ -module { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: 
i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16> - %0 = ub.poison : tensor<64x256xi32> - %1 = ub.poison : tensor<128x64xi32> - %2 = ub.poison : tensor<256xi32> - %3 = ub.poison : tensor<128xi32> - %4 = ub.poison : tensor<128x256xf32> - %5 = ub.poison : i32 - %c-1_i64 = arith.constant -1 : i64 - %c1_i64 = arith.constant 1 : i64 - %c0_i64 = arith.constant 0 : i64 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32> - %c63_i32 = arith.constant 63 : i32 - %c255_i32 = arith.constant 255 : i32 - %c127_i32 = arith.constant 127 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c132_i32 = arith.constant 132 : i32 - %c64_i32 = arith.constant 64 : i32 - %cst_3 = arith.constant dense<0> : tensor<256xi32> - %cst_4 = arith.constant dense<0> : tensor<128xi32> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %6 = tt.get_program_id x : i32 - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %18 = tt.splat %arg3 : i32 -> tensor<128xi32> - %19 = tt.splat %arg4 : i32 -> tensor<256xi32> - %20 = tt.splat %arg6 : i32 -> tensor<128x1xi32> - %21 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %22 = tt.splat %arg7 : i32 -> tensor<1x256xi32> - %23 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> - %24 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %25 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %26 = tt.splat %arg8 : i32 -> tensor<128x1xi32> - %27 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> - %28 = tt.splat %arg3 : i32 -> tensor<128x1xi32> - %29 = tt.splat %arg4 : i32 -> tensor<1x256xi32> - %30 = arith.cmpi eq, %12, %c0_i32 : i32 - scf.if %30 { - scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { - %31 = arith.divsi %arg9, %14 : i32 - %32 = arith.muli %31, %c8_i32 : i32 - %33 = arith.subi %8, %32 : i32 - %34 = arith.minsi %33, %c8_i32 : i32 - %35 = arith.remsi %arg9, %34 : i32 - %36 = arith.addi %32, %35 : i32 - %37 = arith.remsi %arg9, %14 : i32 - %38 = arith.divsi %37, %34 : i32 - %39 = arith.muli %36, %c128_i32 : i32 - %40 = arith.muli %38, %c256_i32 : i32 - %41 = tt.splat %39 : i32 -> tensor<128xi32> - %42 = arith.addi %41, %16 : tensor<128xi32> - %43 = tt.splat %40 : i32 -> tensor<256xi32> - %44 = arith.addi %43, %17 : tensor<256xi32> - %45 = tt.expand_dims %42 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %46 = arith.muli %26, %45 : tensor<128x1xi32> - %47 = tt.addptr %27, %46 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - %48 = tt.expand_dims %44 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %49 = tt.broadcast %47 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> - %50 = tt.broadcast %48 : tensor<1x256xi32> -> tensor<128x256xi32> - %51 = tt.addptr %49, %50 : 
tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %52 = arith.cmpi slt, %45, %28 : tensor<128x1xi32> - %53 = arith.cmpi slt, %48, %29 : tensor<1x256xi32> - %54 = tt.broadcast %52 : tensor<128x1xi1> -> tensor<128x256xi1> - %55 = tt.broadcast %53 : tensor<1x256xi1> -> tensor<128x256xi1> - %56 = arith.andi %54, %55 : tensor<128x256xi1> - tt.store %51, %cst, %56 : tensor<128x256x!tt.ptr> - } - } else { - %31 = arith.subi %13, %6 : i32 - %32 = arith.ceildivsi %31, %c132_i32 : i32 - %33 = arith.extsi %12 : i32 to i64 - %34 = arith.maxsi %33, %c1_i64 : i64 - %35 = arith.extsi %32 : i32 to i64 - %36 = arith.muli %35, %34 : i64 - %37 = arith.subi %34, %c1_i64 : i64 - %38:8 = scf.for %arg9 = %c0_i64 to %36 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0) -> (i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) : i64 { - %39 = arith.addi %arg10, %c1_i64 : i64 - %40 = arith.remsi %39, %34 : i64 - %41 = arith.cmpi eq, %40, %c0_i64 : i64 - %42 = arith.select %41, %c0_i32, %arg12 : i32 - %43 = arith.select %41, %cst_0, %arg13 : tensor<128x256xf32> - %44:4 = scf.if %41 -> (tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32>) { - %69 = arith.divsi %arg11, %14 : i32 - %70 = arith.muli %69, %c8_i32 : i32 - %71 = arith.subi %8, %70 : i32 - %72 = arith.minsi %71, %c8_i32 : i32 - %73 = arith.remsi %arg11, %72 : i32 - %74 = arith.addi %70, %73 : i32 - %75 = arith.remsi %arg11, %14 : i32 - %76 = arith.divsi %75, %72 : i32 - %77 = arith.muli %74, %c128_i32 : i32 - %78 = arith.muli %76, %c256_i32 : i32 - %79 = tt.splat %77 : i32 -> tensor<128xi32> - %80 = arith.addi %79, %16 : tensor<128xi32> - %81 = tt.splat %78 : i32 -> tensor<256xi32> - %82 = arith.addi %81, %17 : tensor<256xi32> - %83 = arith.cmpi slt, %80, %18 : tensor<128xi32> - %84 = arith.select %83, %80, %cst_4 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> - %85 = arith.cmpi slt, %82, %19 : tensor<256xi32> - %86 = arith.select %85, %82, %cst_3 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> - %87 = tt.expand_dims %84 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %88 = arith.muli %87, %20 : tensor<128x1xi32> - %89 = tt.broadcast %88 : tensor<128x1xi32> -> tensor<128x64xi32> - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %91 = arith.muli %90, %22 : tensor<1x256xi32> - %92 = tt.broadcast %91 : tensor<1x256xi32> -> tensor<64x256xi32> - scf.yield %80, %82, %89, %92 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> - } else { - scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> - } - %45 = arith.muli %42, %c64_i32 : i32 - %46 = tt.splat %45 : i32 -> tensor<64xi32> - %47 = arith.addi %46, %15 : tensor<64xi32> - %48 = tt.expand_dims %47 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %49 = tt.broadcast %48 : tensor<1x64xi32> -> tensor<128x64xi32> - %50 = arith.addi %44#2, %49 : tensor<128x64xi32> - %51 = tt.addptr %21, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %52 = tt.expand_dims %47 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %53 = tt.broadcast %52 : tensor<64x1xi32> -> tensor<64x256xi32> - %54 = arith.addi %53, %44#3 : tensor<64x256xi32> - %55 = tt.addptr %23, %54 : 
tensor<64x256x!tt.ptr>, tensor<64x256xi32> - %56 = arith.subi %arg5, %45 : i32 - %57 = tt.splat %56 : i32 -> tensor<1x64xi32> - %58 = arith.cmpi slt, %24, %57 : tensor<1x64xi32> - %59 = tt.broadcast %58 : tensor<1x64xi1> -> tensor<128x64xi1> - %60 = tt.load %51, %59, %cst_2 : tensor<128x64x!tt.ptr> - %61 = tt.splat %56 : i32 -> tensor<64x1xi32> - %62 = arith.cmpi slt, %25, %61 : tensor<64x1xi32> - %63 = tt.broadcast %62 : tensor<64x1xi1> -> tensor<64x256xi1> - %64 = tt.load %55, %63, %cst_1 : tensor<64x256x!tt.ptr> - %65 = tt.dot %60, %64, %43, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> - %66 = arith.addi %42, %c1_i32 : i32 - %67 = arith.cmpi eq, %40, %37 : i64 - %68 = scf.if %67 -> (i32) { - %69 = tt.expand_dims %44#0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %70 = arith.muli %26, %69 : tensor<128x1xi32> - %71 = tt.addptr %27, %70 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - %72 = tt.expand_dims %44#1 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %73 = tt.broadcast %71 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> - %74 = tt.broadcast %72 : tensor<1x256xi32> -> tensor<128x256xi32> - %75 = tt.addptr %73, %74 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %76 = arith.cmpi slt, %69, %28 : tensor<128x1xi32> - %77 = arith.cmpi slt, %72, %29 : tensor<1x256xi32> - %78 = tt.broadcast %76 : tensor<128x1xi1> -> tensor<128x256xi1> - %79 = tt.broadcast %77 : tensor<1x256xi1> -> tensor<128x256xi1> - %80 = arith.andi %78, %79 : tensor<128x256xi1> - %81 = arith.truncf %65 : tensor<128x256xf32> to tensor<128x256xf16> - tt.store %75, %81, %80 : tensor<128x256x!tt.ptr> - %82 = arith.addi %arg11, %c132_i32 : i32 - scf.yield %82 : i32 - } else { - scf.yield %arg11 : i32 - } - scf.yield %40, %68, %66, %65, %44#0, %44#1, %44#2, %44#3 : i64, i32, i32, tensor<128x256xf32>, tensor<128xi32>, tensor<256xi32>, tensor<128x64xi32>, tensor<64x256xi32> - } - } - tt.return - } -} - diff --git a/test2.mlir b/test2.mlir deleted file mode 100644 index 9f3a4304fd97..000000000000 --- a/test2.mlir +++ /dev/null @@ -1,203 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0> : tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 - %c-1_i64 = arith.constant -1 : i64 - %0 = ub.poison : i32 - %1 = ub.poison : tensor<128x256xf32, #mma> - %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %4 = ub.poison : tensor<128x64xi32, #blocked1> - %5 = ub.poison : tensor<64x256xi32, #blocked> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked3> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %6 = tt.get_program_id x : i32 - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> - %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> - %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %23 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %24 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %25 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %26 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %27 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %28 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %29 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %30 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked3> - %32 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked3> - %34 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, 
#blocked3> - %36 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked3> - %38 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %39 = arith.cmpi eq, %12, %c0_i32 : i32 - scf.if %39 { - scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { - %40 = arith.divsi %arg9, %14 : i32 - %41 = arith.muli %40, %c8_i32 : i32 - %42 = arith.subi %8, %41 : i32 - %43 = arith.minsi %42, %c8_i32 : i32 - %44 = arith.remsi %arg9, %43 : i32 - %45 = arith.addi %41, %44 : i32 - %46 = arith.remsi %arg9, %14 : i32 - %47 = arith.divsi %46, %43 : i32 - %48 = arith.muli %45, %c128_i32 : i32 - %49 = arith.muli %47, %c256_i32 : i32 - %50 = tt.splat %48 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> - %51 = arith.addi %50, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> - %52 = tt.splat %49 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> - %53 = arith.addi %52, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> - %54 = tt.expand_dims %51 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1xi32, #blocked3> - %55 = arith.muli %31, %54 : tensor<128x1xi32, #blocked3> - %56 = tt.addptr %33, %55 : tensor<128x1x!tt.ptr, #blocked3>, tensor<128x1xi32, #blocked3> - %57 = tt.expand_dims %53 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x256xi32, #blocked3> - %58 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked3> -> tensor<128x256x!tt.ptr, #blocked3> - %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked3> -> tensor<128x256xi32, #blocked3> - %60 = tt.addptr %58, %59 : tensor<128x256x!tt.ptr, #blocked3>, tensor<128x256xi32, #blocked3> - %61 = arith.cmpi slt, %54, %35 : tensor<128x1xi32, #blocked3> - %62 = arith.cmpi slt, %57, %37 : tensor<1x256xi32, #blocked3> - %63 = tt.broadcast %61 : tensor<128x1xi1, #blocked3> -> tensor<128x256xi1, #blocked3> - %64 = tt.broadcast %62 : tensor<1x256xi1, #blocked3> -> tensor<128x256xi1, #blocked3> - %65 = arith.andi %63, %64 : tensor<128x256xi1, #blocked3> - tt.store %60, %cst_1, %65 : tensor<128x256x!tt.ptr, #blocked3> - } - } else { - %40 = arith.subi %13, %6 : i32 - %41 = arith.ceildivsi %40, %c132_i32 : i32 - %42 = arith.extsi %12 : i32 to i64 - %43 = arith.maxsi %42, %c1_i64 : i64 - %44 = arith.extsi %41 : i32 to i64 - %45 = arith.muli %44, %43 : i64 - %46 = arith.subi %43, %c1_i64 : i64 - %true = arith.constant true - %false = arith.constant false - %47:9 = scf.for %arg9 = %c0_i64 to %45 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>, i1) : i64 { - %48 = arith.addi %arg10, %c1_i64 : i64 - %49 = arith.remsi %48, %43 : i64 - %50 = arith.cmpi eq, %49, %c0_i64 : i64 - %51 = arith.select %50, %c0_i32, %arg12 : i32 - %52 = arith.select %50, %false, %arg18 : i1 - %53:4 = scf.if %50 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>) { - %82 = arith.divsi %arg11, %14 : i32 - %83 = arith.muli %82, %c8_i32 : i32 - %84 = arith.subi %8, %83 : i32 - %85 = arith.minsi %84, %c8_i32 : i32 - %86 = arith.remsi %arg11, %85 : i32 - %87 = 
arith.addi %83, %86 : i32 - %88 = arith.remsi %arg11, %14 : i32 - %89 = arith.divsi %88, %85 : i32 - %90 = arith.muli %87, %c128_i32 : i32 - %91 = arith.muli %89, %c256_i32 : i32 - %92 = tt.splat %90 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %93 = tt.splat %90 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %94 = arith.addi %92, %19 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %95 = arith.addi %93, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %96 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %97 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %98 = arith.addi %96, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %99 = arith.addi %97, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %100 = arith.cmpi slt, %94, %23 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %101 = arith.select %100, %94, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %102 = arith.cmpi slt, %98, %24 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %103 = arith.select %102, %98, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %104 = tt.expand_dims %101 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %105 = arith.muli %104, %25 : tensor<128x1xi32, #blocked1> - %106 = tt.broadcast %105 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %107 = tt.expand_dims %103 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %108 = arith.muli %107, %27 : tensor<1x256xi32, #blocked> - %109 = tt.broadcast %108 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - scf.yield %106, %109, %99, %95 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - } else { - scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - } - %54 = arith.muli %51, %c64_i32 : i32 - %55 = tt.splat %54 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %56 = tt.splat %54 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %57 = arith.addi %55, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %58 = arith.addi %56, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %59 = tt.expand_dims %57 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %60 = tt.broadcast %59 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %61 = arith.addi %53#0, %60 : tensor<128x64xi32, #blocked1> - %62 = tt.addptr %26, %61 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %63 = tt.expand_dims %58 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %64 = tt.broadcast %63 : tensor<64x1xi32, 
#blocked> -> tensor<64x256xi32, #blocked> - %65 = arith.addi %64, %53#1 : tensor<64x256xi32, #blocked> - %66 = tt.addptr %28, %65 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %67 = arith.subi %arg5, %54 : i32 - %68 = tt.splat %67 : i32 -> tensor<1x64xi32, #blocked1> - %69 = arith.cmpi slt, %29, %68 : tensor<1x64xi32, #blocked1> - %70 = tt.broadcast %69 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %71 = tt.load %62, %70, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> - %72 = ttg.local_alloc %71 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %73 = tt.splat %67 : i32 -> tensor<64x1xi32, #blocked> - %74 = arith.cmpi slt, %30, %73 : tensor<64x1xi32, #blocked> - %75 = tt.broadcast %74 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %76 = tt.load %66, %75, %cst_3 : tensor<64x256x!tt.ptr, #blocked> - %77 = ttg.local_alloc %76 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> - %78 = ttng.warp_group_dot %72, %77, %arg13, %52 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> - %79 = arith.addi %51, %c1_i32 : i32 - %80 = arith.cmpi eq, %49, %46 : i64 - %81 = scf.if %80 -> (i32) { - %82 = tt.expand_dims %53#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %83 = arith.muli %32, %82 : tensor<128x1xi32, #blocked2> - %84 = tt.addptr %34, %83 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %85 = tt.expand_dims %53#2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %86 = tt.broadcast %84 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %87 = tt.broadcast %85 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %88 = tt.addptr %86, %87 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %89 = arith.cmpi slt, %82, %36 : tensor<128x1xi32, #blocked2> - %90 = arith.cmpi slt, %85, %38 : tensor<1x256xi32, #blocked2> - %91 = tt.broadcast %89 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %92 = tt.broadcast %90 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %93 = arith.andi %91, %92 : tensor<128x256xi1, #blocked2> - %94 = arith.truncf %78 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %95 = ttg.convert_layout %94 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %88, %95, %93 : tensor<128x256x!tt.ptr, #blocked2> - %96 = arith.addi %arg11, %c132_i32 : i32 - scf.yield %96 : i32 - } else { - scf.yield %arg11 : i32 - } - scf.yield %49, %81, %79, %78, %53#0, %53#1, %53#2, %53#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>, i1 - } - } - tt.return - } -} - diff --git a/test3.mlir b/test3.mlir deleted file mode 100644 index 9b84523f6053..000000000000 --- a/test3.mlir +++ /dev/null @@ -1,197 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = 
#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %false = arith.constant false - %true = arith.constant true - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 - %c-1_i64 = arith.constant -1 : i64 - %0 = ub.poison : i32 - %1 = ub.poison : tensor<128x256xf32, #mma> - %2 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %3 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %4 = ub.poison : tensor<128x64xi32, #blocked1> - %5 = ub.poison : tensor<64x256xi32, #blocked> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked2> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %6 = tt.get_program_id x : i32 - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %23 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %24 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %25 = tt.splat %arg0 : !tt.ptr 
-> tensor<128x64x!tt.ptr, #blocked1> - %26 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %27 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %28 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %29 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %30 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %31 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> - %32 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %33 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %34 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %35 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> - %36 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %37 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> - %38 = arith.cmpi eq, %12, %c0_i32 : i32 - scf.if %38 { - scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { - %39 = arith.divsi %arg9, %14 : i32 - %40 = arith.muli %39, %c8_i32 : i32 - %41 = arith.subi %8, %40 : i32 - %42 = arith.minsi %41, %c8_i32 : i32 - %43 = arith.remsi %arg9, %42 : i32 - %44 = arith.addi %40, %43 : i32 - %45 = arith.remsi %arg9, %14 : i32 - %46 = arith.divsi %45, %42 : i32 - %47 = arith.muli %44, %c128_i32 : i32 - %48 = arith.muli %46, %c256_i32 : i32 - %49 = tt.splat %47 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %50 = arith.addi %49, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %51 = tt.splat %48 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %52 = arith.addi %51, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %53 = tt.expand_dims %50 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %54 = arith.muli %30, %53 : tensor<128x1xi32, #blocked2> - %55 = tt.addptr %32, %54 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %56 = tt.expand_dims %52 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %57 = tt.broadcast %55 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %59 = tt.addptr %57, %58 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %60 = arith.cmpi slt, %53, %34 : tensor<128x1xi32, #blocked2> - %61 = arith.cmpi slt, %56, %36 : tensor<1x256xi32, #blocked2> - %62 = tt.broadcast %60 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %63 = tt.broadcast %61 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %64 = arith.andi %62, %63 : tensor<128x256xi1, #blocked2> - tt.store %59, %cst_1, %64 : tensor<128x256x!tt.ptr, #blocked2> - } - } else { - %39 = arith.subi %13, %6 : i32 - %40 = arith.ceildivsi %39, %c132_i32 : i32 - %41 = arith.extsi %12 : i32 to i64 - %42 = arith.maxsi %41, %c1_i64 : i64 - %43 = arith.extsi %40 : i32 to i64 - %44 = arith.muli %43, %42 : i64 - %45 = arith.subi %42, %c1_i64 : i64 - %46:9 = scf.for %arg9 = %c0_i64 to %44 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %6, %arg12 = %0, %arg13 = %1, %arg14 = %4, %arg15 = %5, %arg16 = %3, %arg17 = %2, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1) : i64 { - %47 = arith.addi %arg10, %c1_i64 : i64 - %48 = arith.remsi %47, %42 : i64 - %49 = arith.cmpi eq, %48, %c0_i64 : i64 - %50 = arith.select %49, %c0_i32, %arg12 : i32 - %51 = arith.select %49, %false, %arg18 : i1 - %52:4 = scf.if %49 -> (tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) { - %81 = arith.divsi %arg11, %14 : i32 - %82 = arith.muli %81, %c8_i32 : i32 - %83 = arith.subi %8, %82 : i32 - %84 = arith.minsi %83, %c8_i32 : i32 - %85 = arith.remsi %arg11, %84 : i32 - %86 = arith.addi %82, %85 : i32 - %87 = arith.remsi %arg11, %14 : i32 - %88 = arith.divsi %87, %84 : i32 - %89 = arith.muli %86, %c128_i32 : i32 - %90 = arith.muli %88, %c256_i32 : i32 - %91 = tt.splat %89 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %92 = arith.addi %91, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %93 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %94 = tt.splat %90 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %95 = arith.addi %93, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %96 = arith.addi %94, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %97 = arith.cmpi slt, %92, %22 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %98 = arith.select %97, %92, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %99 = arith.cmpi slt, %95, %23 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %100 = arith.select %99, %95, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %101 = tt.expand_dims %98 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %102 = arith.muli %101, %24 : tensor<128x1xi32, #blocked1> - %103 = tt.broadcast %102 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %104 = tt.expand_dims %100 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %105 = arith.muli %104, %26 : tensor<1x256xi32, #blocked> - %106 = tt.broadcast %105 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - scf.yield %103, %106, %96, %92 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - } else { - scf.yield %arg14, %arg15, %arg16, %arg17 : tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - } - %53 = arith.muli %50, %c64_i32 : i32 - %54 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %55 = tt.splat %53 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %56 = arith.addi %54, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %57 = arith.addi %55, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %58 = tt.expand_dims %56 {axis = 0 : i32} : 
tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %59 = tt.broadcast %58 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %60 = arith.addi %52#0, %59 : tensor<128x64xi32, #blocked1> - %61 = tt.addptr %25, %60 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %62 = tt.expand_dims %57 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %63 = tt.broadcast %62 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %64 = arith.addi %63, %52#1 : tensor<64x256xi32, #blocked> - %65 = tt.addptr %27, %64 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %66 = arith.subi %arg5, %53 : i32 - %67 = tt.splat %66 : i32 -> tensor<1x64xi32, #blocked1> - %68 = arith.cmpi slt, %28, %67 : tensor<1x64xi32, #blocked1> - %69 = tt.broadcast %68 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %70 = tt.load %61, %69, %cst_2 : tensor<128x64x!tt.ptr, #blocked1> - %71 = ttg.local_alloc %70 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %72 = tt.splat %66 : i32 -> tensor<64x1xi32, #blocked> - %73 = arith.cmpi slt, %29, %72 : tensor<64x1xi32, #blocked> - %74 = tt.broadcast %73 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %75 = tt.load %65, %74, %cst_3 : tensor<64x256x!tt.ptr, #blocked> - %76 = ttg.local_alloc %75 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared, #smem> - %77 = ttng.warp_group_dot %71, %76, %arg13, %51 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> - %78 = arith.addi %50, %c1_i32 : i32 - %79 = arith.cmpi eq, %48, %45 : i64 - %80 = scf.if %79 -> (i32) { - %81 = tt.expand_dims %52#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %82 = arith.muli %31, %81 : tensor<128x1xi32, #blocked1> - %83 = tt.addptr %33, %82 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %84 = tt.expand_dims %52#2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %85 = tt.broadcast %83 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> - %86 = tt.broadcast %84 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> - %87 = tt.addptr %85, %86 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> - %88 = arith.cmpi slt, %81, %35 : tensor<128x1xi32, #blocked1> - %89 = arith.cmpi slt, %84, %37 : tensor<1x256xi32, #blocked1> - %90 = tt.broadcast %88 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> - %91 = tt.broadcast %89 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> - %92 = arith.andi %90, %91 : tensor<128x256xi1, #blocked1> - %93 = arith.truncf %77 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %94 = ttg.convert_layout %93 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> - tt.store %87, %94, %92 : tensor<128x256x!tt.ptr, #blocked1> - %95 = arith.addi %arg11, %c132_i32 : i32 - scf.yield %95 : i32 - } else { - scf.yield %arg11 : i32 - } - scf.yield %48, %80, %78, %77, %52#0, %52#1, %52#2, %52#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128x64xi32, #blocked1>, tensor<64x256xi32, #blocked>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, i1 - } - } - tt.return - } -} - diff --git 
a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index cf36320b828c..c29c325cf8b9 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -242,6 +242,8 @@ def make_ttgir(mod, metadata, options):
                           "num_stages == 0. Now it will not happen anymore; "
                           "please update to use num_stages == 2 for "
                           "equivalent behavior in the past.")
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
         amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch)
         passes.common.add_canonicalizer(pm)
         amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 960334744384..2411e85dcaa9 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -253,6 +253,8 @@ def make_ttgir(mod, metadata, opt, capability):
         if capability // 10 >= 8:
             passes.ttgpuir.add_optimize_accumulator_init(pm)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+            passes.ttgpuir.add_fuse_nested_loops(pm)
+            passes.common.add_canonicalizer(pm)
             passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages)
             passes.ttgpuir.add_pipeline(pm, opt.num_stages)
             passes.ttgpuir.add_prefetch(pm)

From d6fb02d576261cda19298e94913ae9ebfb8740e6 Mon Sep 17 00:00:00 2001
From: Mogball
Date: Fri, 24 Jan 2025 21:59:10 -0500
Subject: [PATCH 05/32] I cry

---
 .../TritonGPU/Transforms/FuseNestedLoops.cpp  | 125 ++++-
 .../Transforms/Pipeliner/PipelineExpander.cpp |  12 +
 new.mlir                                      | 383 ++++++++++++++
 orig.mlir                                     | 386 ++++++++++++++
 orig2.mlir                                    | 293 +++++++++++
 python/tutorials/09-persistent-matmul.py      | 482 +++++++++++++++++-
 test.mlir                                     | 178 +++++++
 test2.mlir                                    | 128 +++++
 test3.mlir                                    | 177 +++++++
 test4.mlir                                    | 192 +++++++
 test5.mlir                                    | 345 +++++++++++++
 third_party/amd/backend/compiler.py           |   2 -
 third_party/nvidia/backend/compiler.py        |   7 +-
 13 files changed, 2681 insertions(+), 29 deletions(-)
 create mode 100644 new.mlir
 create mode 100644 orig.mlir
 create mode 100644 orig2.mlir
 create mode 100644 test.mlir
 create mode 100644 test2.mlir
 create mode 100644 test3.mlir
 create mode 100644 test4.mlir
 create mode 100644 test5.mlir

diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
index c246ce9b9a92..a0899815939c 100644
--- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
@@ -234,7 +234,7 @@ static Logue createLogueFrom(llvm::iterator_range ops,
 // recursively.
 static bool canHoistLoopBoundComputation(Operation *op) {
   auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); };
-  return isMemoryEffectFree(op) &&
+  return isPure(op) && op->hasTrait() &&
          llvm::all_of(op->getOperandTypes(), isScalar) &&
          llvm::all_of(op->getResultTypes(), isScalar);
 }
@@ -819,6 +819,122 @@ static bool shouldFuse(const LoopNest &nest) {
   });
 }
 
+// Loop-invariant code motion can increase register pressure in combination with
+// loop nest fusion. Values hoisted out of the inner loop and into the prologue
+// that are directly used inside the inner loop will need to be added as iter
+// args to the fused loop, substantially increasing their live range.
+//
+// This function identifies a subgraph of cheap ops that can be sunk and
+// determines if doing so will reduce register pressure.
+static void sinkHeavyOps(Region &limit, Block *sinkBlock,
+                         Block::iterator sinkBefore,
+                         llvm::iterator_range<Block::iterator> prologue,
+                         function_ref<bool(Operation *)> inSinkRegion,
+                         function_ref<bool(size_t, size_t)> shouldSink) {
+  llvm::SetVector<Operation *> sunkOps;
+  auto canBeSunk = [&](Operation &op) -> std::pair<bool, bool> {
+    if (!isPure(&op) || op.hasTrait())
+      return {false, false};
+    // An op can be sunk if all its users are inside the inner loop or are
+    // marked for sinking.
+    bool isRoot = true;
+    for (Operation *user : op.getUsers()) {
+      if (inSinkRegion(user))
+        continue;
+      isRoot = false;
+      if (sunkOps.contains(user))
+        continue;
+      return {false, false};
+    }
+    return {true, isRoot};
+  };
+
+  // Find the subgraph of operations that can be sunk.
+  SmallVector<Operation *> roots;
+  for (Operation &op : llvm::reverse(prologue)) {
+    auto [canSink, isRoot] = canBeSunk(op);
+    if (canSink)
+      sunkOps.insert(&op);
+    if (isRoot)
+      roots.push_back(&op);
+  }
+  if (sunkOps.empty())
+    return;
+
+  // Analyze sinking the whole subgraph at once. Breaking up the subgraph is
+  // a more complicated analysis.
+  //
+  // Compute the total size of the fan-ins and fan-outs as the number of
+  // registers per thread used by the value. This is a heuristic.
+  MLIRContext *ctx = sunkOps.front()->getContext();
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto getSizeEstimate = [&](Type type) {
+    auto tensor = dyn_cast<RankedTensorType>(type);
+    if (!tensor)
+      return 1;
+    LinearLayout layout =
+        toLinearLayout(tensor.getShape(), tensor.getEncoding());
+    return layout.getInDimSize(kRegister);
+  };
+
+  size_t fanOutSize = 0;
+  for (Operation *root : roots) {
+    for (Value result : root->getResults()) {
+      if (result.use_empty())
+        continue;
+      fanOutSize += getSizeEstimate(result.getType());
+    }
+  }
+
+  size_t fanInSize = 0;
+  DenseSet<Value> checked;
+  for (Operation *op : sunkOps) {
+    for (Value operand : op->getOperands()) {
+      // Count each operand only once.
+      if (!checked.insert(operand).second)
+        continue;
+      if (sunkOps.contains(operand.getDefiningOp()))
+        continue;
+      if (operand.getParentRegion()->isProperAncestor(&limit))
+        continue;
+      if (llvm::any_of(operand.getUsers(), inSinkRegion))
+        continue;
+      fanInSize += getSizeEstimate(operand.getType());
+    }
+  }
+
+  // Only sink if this will lead to a large reduction.
+  if (shouldSink(fanInSize, fanOutSize)) {
+    sunkOps = topologicalSort(sunkOps);
+    for (Operation *op : sunkOps)
+      op->moveBefore(sinkBlock, sinkBefore);
+  }
+}
+
+// Sink ops into the inner loop and from the prologue into the epilogue.
+static void sinkHeavyOps(scf::ForOp outerLoop, scf::ForOp innerLoop,
+                         mlir::DominanceInfo &domInfo) {
+  Region &limit = outerLoop.getBodyRegion();
+  auto inInnerLoop = [&](Operation *op) {
+    return innerLoop.getBodyRegion().isAncestor(op->getParentRegion());
+  };
+  //sinkHeavyOps(limit, innerLoop.getBody(), innerLoop.getBody()->begin(),
+  //             {outerLoop.getBody()->begin(), innerLoop->getIterator()},
+  //             inInnerLoop, [&](size_t fanInSize, size_t fanOutSize) {
+  //               return fanInSize * 4 <= fanOutSize;
+  //             });
+
+  // Move computations in the prologue that can be done in the epilogue. This is
+  // always beneficial.
+  auto inEpilogue = [&](Operation *op) {
+    return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false);
+  };
+  sinkHeavyOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()),
+               {outerLoop.getBody()->begin(), innerLoop->getIterator()},
+               inEpilogue,
+               [&](size_t fanInSize, size_t fanOutSize) { return true; });
+}
+
+// Speculate the length of the inner loop such that the loop is known to execute
This way, the inner loop body does not have to be placed // inside a conditional in the fused loop, which interacts better with the @@ -830,6 +946,13 @@ static LogicalResult speculateInnerLoopLength(const LoopNest &nest, scf::ForOp outerLoop = nest.root->loop; scf::ForOp innerLoop = nest.root->children.front()->loop; + // Sink heavy ops first. + sinkHeavyOps(outerLoop, innerLoop, domInfo); + + innerLoop->setAttr(kMustExecuteAttrName, + UnitAttr::get(outerLoop.getContext())); + return success(); + // The inner loop bounds must be outer-loop invariant to speculate from // outside the loop nest. Location loc = innerLoop.getLoc(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index d7ff515269ae..b366da9b8f2e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -285,7 +285,19 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { setValueMapping(arg, operand.get(), 0); } + + // If the incoming value to an iter arg from the loop yield is defined outside + // the loop, then that means the iter arg takes that value for all stages + // after the first stage. auto yield = cast(forOp.getBody()->getTerminator()); + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), yield->getOpOperands())) { + if (forOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) + continue; + for (int64_t i = 1; i < maxStage; ++i) + setValueMapping(arg, operand.get(), i); + } + Location loc = forOp.getLoc(); SmallVector predicates(maxStage); for (int64_t i = 0; i < maxStage; i++) { diff --git a/new.mlir b/new.mlir new file mode 100644 index 000000000000..24bf404cb01f --- /dev/null +++ b/new.mlir @@ -0,0 +1,383 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), 
%arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0)) attributes {noinline = false} { + %c2_i64 = arith.constant 2 : i64 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %0 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + %1 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) + %2 = ub.poison : tensor<128x256xf32, #mma> loc(#loc1) + %3 = ub.poison : i32 loc(#loc1) + %c1_i64 = arith.constant 1 : i64 loc(#loc1) + %c0_i64 = arith.constant 0 : i64 loc(#loc1) + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c132_i32 = arith.constant 132 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c255_i32 = arith.constant 255 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %4 = tt.get_program_id x : i32 loc(#loc2) + %5 = arith.addi %arg3, %c127_i32 : i32 loc(#loc59) + %6 = arith.divsi %5, %c128_i32 : i32 loc(#loc60) + %7 = arith.addi %arg4, %c255_i32 : i32 loc(#loc61) + %8 = arith.divsi %7, %c256_i32 : i32 loc(#loc62) + %9 = arith.addi %arg5, %c63_i32 : i32 loc(#loc63) + %10 = arith.divsi %9, %c64_i32 : i32 loc(#loc64) + %11 = arith.muli %6, %8 : i32 loc(#loc8) + %12 = arith.muli %8, %c8_i32 : i32 loc(#loc9) + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10) + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc10) + %15 = arith.subi %11, %4 : i32 loc(#loc11) + %16 = arith.ceildivsi %15, %c132_i32 : i32 loc(#loc11) + %17 = arith.extsi %10 : i32 to i64 loc(#loc11) + %18 = arith.maxsi %17, %c1_i64 : i64 loc(#loc11) + %19 = arith.extsi %16 : i32 to i64 loc(#loc11) + %20 = arith.muli %19, %18 : i64 loc(#loc11) + %21 = arith.subi %4, %c132_i32 : i32 loc(#loc11) + %22 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc12) + %23 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc13) + %24 = arith.cmpi sgt, %20, %c0_i64 : i64 loc(#loc11) + %25 = arith.remsi %c0_i64, %18 : i64 loc(#loc11) + %26 = arith.cmpi eq, %25, %c0_i64 : i64 loc(#loc11) + %27 = arith.select %26, %4, %21 : i32 loc(#loc11) + %28 = arith.cmpi ne, %25, %c0_i64 : i64 loc(#loc11) + %29:4 = scf.if %26 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %110 = arith.divsi %4, %12 : i32 loc(#loc14) + %111 = arith.muli %110, %c8_i32 : i32 loc(#loc15) + %112 = arith.subi %6, %111 : i32 loc(#loc16) + 
%113 = arith.minsi %112, %c8_i32 : i32 loc(#loc17) + %114 = arith.remsi %4, %113 : i32 loc(#loc18) + %115 = arith.addi %111, %114 : i32 loc(#loc19) + %116 = arith.remsi %4, %12 : i32 loc(#loc20) + %117 = arith.divsi %116, %113 : i32 loc(#loc21) + %118 = arith.muli %115, %c128_i32 : i32 loc(#loc22) + %119 = arith.muli %117, %c256_i32 : i32 loc(#loc23) + %120 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %121 = tt.splat %118 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %122 = arith.addi %121, %120 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %123 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %124 = tt.splat %119 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %125 = arith.addi %124, %123 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %126 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %127 = arith.cmpi slt, %122, %126 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %128 = arith.select %127, %122, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %129 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %130 = arith.cmpi slt, %125, %129 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %131 = arith.select %130, %125, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %118, %119, %128, %131 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) + } else { + scf.yield %3, %3, %1, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) + } loc(#loc11) + %30 = tt.expand_dims %29#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %31 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc33) + %32 = arith.muli %30, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) + %33 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) + %34 = tt.broadcast %32 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1> loc(#loc35) + %37 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc36) + %38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %39 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %40 = tt.expand_dims %29#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %41 = 
tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc39) + %42 = arith.muli %40, %41 : tensor<1x256xi32, #blocked> loc(#loc39) + %43 = tt.broadcast %39 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %44 = tt.broadcast %42 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %45 = arith.addi %43, %44 : tensor<64x256xi32, #blocked> loc(#loc40) + %46 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc41) + %47 = tt.addptr %46, %45 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %48 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %49 = arith.cmpi slt, %33, %48 : tensor<1x64xi32, #blocked1> loc(#loc42) + %50 = tt.broadcast %49 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %51 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %52 = tt.splat %24 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %53 = arith.andi %52, %50 : tensor<128x64xi1, #blocked1> loc(#loc11) + %54 = ttg.async_copy_global_to_local %38, %51 mask %53 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %55 = ttg.async_commit_group %54 loc(#loc12) + %56 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) + %57 = arith.cmpi slt, %39, %56 : tensor<64x1xi32, #blocked> loc(#loc43) + %58 = tt.broadcast %57 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %59 = ttg.memdesc_subview %23[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %60 = tt.splat %24 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %61 = arith.andi %60, %58 : tensor<64x256xi1, #blocked> loc(#loc11) + %62 = ttg.async_copy_global_to_local %47, %59 mask %61 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %63 = ttg.async_commit_group %62 loc(#loc13) + %64 = arith.cmpi sgt, %20, %c1_i64 : i64 loc(#loc11) + %65 = arith.addi %25, %c1_i64 : i64 loc(#loc11) + %66 = arith.remsi %65, %18 : i64 loc(#loc11) + %67 = arith.cmpi eq, %66, %c0_i64 : i64 loc(#loc11) + %68 = arith.cmpi ne, %66, %c0_i64 : i64 loc(#loc11) + %69 = arith.extui %68 : i1 to i32 loc(#loc11) + %70:5 = scf.if %67 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %110 = arith.addi %27, %c132_i32 : i32 loc(#loc11) + %111 = arith.divsi %110, %12 : i32 loc(#loc14) + %112 = arith.muli %111, %c8_i32 : i32 loc(#loc15) + %113 = arith.subi %6, %112 : i32 loc(#loc16) + %114 = arith.minsi %113, %c8_i32 : i32 loc(#loc17) + %115 = arith.remsi %110, %114 : i32 loc(#loc18) + %116 = arith.addi %112, %115 : i32 loc(#loc19) + %117 = arith.remsi %110, %12 : i32 loc(#loc20) + %118 = arith.divsi %117, %114 : i32 loc(#loc21) + %119 = arith.muli %116, %c128_i32 : i32 loc(#loc22) + %120 = arith.muli %118, %c256_i32 : i32 loc(#loc23) + %121 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %122 = tt.splat %119 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %123 = arith.addi %122, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %124 = tt.make_range 
{end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %125 = tt.splat %120 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %126 = arith.addi %125, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %127 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %128 = arith.cmpi slt, %123, %127 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %129 = arith.select %128, %123, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %130 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %131 = arith.cmpi slt, %126, %130 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %132 = arith.select %131, %126, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %119, %120, %129, %132, %110 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + } else { + scf.yield %29#0, %29#1, %29#2, %29#3, %27 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + } loc(#loc11) + %71 = arith.muli %69, %c64_i32 : i32 loc(#loc44) + %72 = tt.splat %71 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %73 = tt.splat %71 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %74 = arith.addi %72, %13 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %75 = arith.addi %73, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %76 = tt.expand_dims %70#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %77 = arith.muli %76, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) + %78 = tt.expand_dims %74 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) + %79 = tt.broadcast %77 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %80 = tt.broadcast %78 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %81 = arith.addi %79, %80 : tensor<128x64xi32, #blocked1> loc(#loc35) + %82 = tt.addptr %37, %81 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %83 = tt.expand_dims %75 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %84 = tt.expand_dims %70#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %85 = arith.muli %84, %41 : tensor<1x256xi32, #blocked> loc(#loc39) + %86 = tt.broadcast %83 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %87 = tt.broadcast %85 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %88 = arith.addi %86, %87 : tensor<64x256xi32, #blocked> loc(#loc40) + %89 = tt.addptr %46, %88 : tensor<64x256x!tt.ptr, 
#blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %90 = arith.subi %arg5, %71 : i32 loc(#loc46) + %91 = tt.splat %90 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %92 = arith.cmpi slt, %33, %91 : tensor<1x64xi32, #blocked1> loc(#loc42) + %93 = tt.broadcast %92 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %94 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %95 = tt.splat %64 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %96 = arith.andi %95, %93 : tensor<128x64xi1, #blocked1> loc(#loc11) + %97 = ttg.async_copy_global_to_local %82, %94 mask %96 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %98 = ttg.async_commit_group %97 loc(#loc12) + %99 = tt.splat %90 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) + %100 = arith.cmpi slt, %39, %99 : tensor<64x1xi32, #blocked> loc(#loc43) + %101 = tt.broadcast %100 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %102 = ttg.memdesc_subview %23[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %103 = tt.splat %64 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %104 = arith.andi %103, %101 : tensor<64x256xi1, #blocked> loc(#loc11) + %105 = ttg.async_copy_global_to_local %89, %102 mask %104 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %106 = ttg.async_commit_group %105 loc(#loc13) + %107:20 = scf.for %arg9 = %c0_i64 to %20 step %c1_i64 iter_args(%arg10 = %66, %arg11 = %70#4, %arg12 = %2, %arg13 = %70#0, %arg14 = %70#1, %arg15 = %70#2, %arg16 = %70#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %69, %arg20 = %63, %arg21 = %106, %arg22 = %28, %arg23 = %68, %arg24 = %25, %arg25 = %66, %arg26 = %29#0, %arg27 = %70#0, %arg28 = %29#1, %arg29 = %70#1) -> ( + i64, i32, + tensor<128x256xf32, #mma>, + i32, i32, + tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, + tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, + i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, i64, i64, i32, i32, i32, i32) : i64 { + %110 = arith.subi %20, %c2_i64 : i64 loc(#loc11) + %111 = arith.cmpi slt, %arg9, %110 : i64 loc(#loc11) + %112 = arith.addi %arg19, %c1_i32 : i32 loc(#loc11) + %113 = arith.addi %arg10, %c1_i64 : i64 loc(#loc11) + %114 = arith.remsi %113, %18 : i64 loc(#loc11) + %115 = arith.cmpi eq, %114, %c0_i64 : i64 loc(#loc11) + %116 = arith.select %115, %c0_i32, %112 : i32 loc(#loc11) + %117 = arith.cmpi ne, %114, %c0_i64 : i64 loc(#loc11) + %118:5 = scf.if %115 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %168 = arith.addi %arg11, %c132_i32 : i32 loc(#loc11) + %169 = arith.divsi %168, %12 : i32 loc(#loc14) + %170 = arith.muli %169, %c8_i32 : i32 loc(#loc15) + %171 = arith.subi %6, %170 : i32 loc(#loc16) + %172 = arith.minsi %171, %c8_i32 : i32 loc(#loc17) + %173 = arith.remsi %168, %172 : i32 loc(#loc18) + %174 = arith.addi %170, %173 : i32 loc(#loc19) + %175 = arith.remsi %168, %12 : i32 loc(#loc20) + %176 = arith.divsi %175, %172 : i32 loc(#loc21) + %177 = arith.muli %174, %c128_i32 : i32 loc(#loc22) + %178 = arith.muli %176, %c256_i32 : i32 loc(#loc23) + %179 = tt.make_range {end = 128 : i32, start = 0 : 
i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %180 = tt.splat %177 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %181 = arith.addi %180, %179 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %182 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %183 = tt.splat %178 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %184 = arith.addi %183, %182 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %185 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %186 = arith.cmpi slt, %181, %185 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %187 = arith.select %186, %181, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %188 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %189 = arith.cmpi slt, %184, %188 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %190 = arith.select %189, %184, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %177, %178, %187, %190, %168 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + } else { + scf.yield %arg13, %arg14, %arg15, %arg16, %arg11 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + } loc(#loc11) + %119 = arith.addi %arg18, %c1_i32 : i32 loc(#loc11) + %120 = arith.cmpi slt, %119, %c3_i32 : i32 loc(#loc11) + %121 = arith.select %120, %119, %c0_i32 : i32 loc(#loc11) + %122 = ttg.memdesc_subview %22[%121, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %123 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc12) + %124 = ttg.memdesc_subview %23[%121, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %125 = ttng.warp_group_dot %122, %124, %arg12, %arg22 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc47) + %126:3 = ttng.warp_group_dot_wait %125, %122, %124 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) + %127 = arith.addi %arg17, %c1_i32 : i32 loc(#loc11) + %128 = arith.cmpi slt, %127, %c3_i32 : i32 loc(#loc11) + %129 = arith.select %128, %127, %c0_i32 : i32 loc(#loc11) + %130 = arith.muli %116, %c64_i32 : i32 loc(#loc44) + %131 = tt.splat %130 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %132 = tt.splat %130 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %133 = arith.addi %131, %13 : 
tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %134 = arith.addi %132, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %135 = tt.expand_dims %118#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %136 = arith.muli %135, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) + %137 = tt.expand_dims %133 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) + %138 = tt.broadcast %136 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %139 = tt.broadcast %137 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %140 = arith.addi %138, %139 : tensor<128x64xi32, #blocked1> loc(#loc35) + %141 = tt.addptr %37, %140 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %142 = tt.expand_dims %134 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %143 = tt.expand_dims %118#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %144 = arith.muli %143, %41 : tensor<1x256xi32, #blocked> loc(#loc39) + %145 = tt.broadcast %142 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %146 = tt.broadcast %144 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %147 = arith.addi %145, %146 : tensor<64x256xi32, #blocked> loc(#loc40) + %148 = tt.addptr %46, %147 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %149 = arith.subi %arg5, %130 : i32 loc(#loc46) + %150 = tt.splat %149 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %151 = arith.cmpi slt, %33, %150 : tensor<1x64xi32, #blocked1> loc(#loc42) + %152 = tt.broadcast %151 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %153 = ttg.memdesc_subview %22[%129, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %154 = tt.splat %111 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %155 = arith.andi %154, %152 : tensor<128x64xi1, #blocked1> loc(#loc11) + %156 = ttg.async_copy_global_to_local %141, %153 mask %155 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %157 = ttg.async_commit_group %156 loc(#loc12) + %158 = tt.splat %149 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) + %159 = arith.cmpi slt, %39, %158 : tensor<64x1xi32, #blocked> loc(#loc43) + %160 = tt.broadcast %159 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %161 = ttg.memdesc_subview %23[%129, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %162 = tt.splat %111 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %163 = arith.andi %162, %160 : tensor<64x256xi1, #blocked> loc(#loc11) + %164 = ttg.async_copy_global_to_local %148, %161 mask %163 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %165 = ttg.async_commit_group %164 loc(#loc13) + %166 = arith.subi %18, %c1_i64 : i64 loc(#loc11) + %167 = arith.cmpi eq, %arg24, %166 : i64 loc(#loc11) + scf.if %167 { + %168:3 = ttng.warp_group_dot_wait %126#0, %122, %124 {pendings = 0 : i32} : 
tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) + %169 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc24) + %170 = tt.splat %arg26 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) + %171 = arith.addi %170, %169 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) + %172 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc26) + %173 = tt.splat %arg28 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) + %174 = arith.addi %173, %172 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) + %175 = tt.expand_dims %171 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc48) + %176 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc49) + %177 = arith.muli %176, %175 : tensor<128x1xi32, #blocked2> loc(#loc49) + %178 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc50) + %179 = tt.addptr %178, %177 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc50) + %180 = tt.expand_dims %174 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc51) + %181 = tt.broadcast %179 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc52) + %182 = tt.broadcast %180 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc52) + %183 = tt.addptr %181, %182 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc52) + %184 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc53) + %185 = arith.cmpi slt, %175, %184 : tensor<128x1xi32, #blocked2> loc(#loc53) + %186 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc54) + %187 = arith.cmpi slt, %180, %186 : tensor<1x256xi32, #blocked2> loc(#loc54) + %188 = tt.broadcast %185 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) + %189 = tt.broadcast %187 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) + %190 = arith.andi %188, %189 : tensor<128x256xi1, #blocked2> loc(#loc55) + %191 = arith.truncf %168#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc56) + %192 = ttg.convert_layout %191 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc57) + tt.store %183, %192, %190 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc57) + } loc(#loc11) + scf.yield %114, %118#4, %126#0, %118#0, %118#1, %118#2, %118#3, %129, %121, %116, %arg21, %165, %arg23, %117, %arg25, %114, %arg27, %118#0, %arg29, %118#1 : i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, i64, i64, i32, i32, i32, i32 loc(#loc11) + } loc(#loc11) + %108 = ttng.warp_group_dot_wait %107#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc11) + %109 = ttg.async_wait {num = 0 : i32} loc(#loc11) + ttg.local_dealloc %22 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc11) + ttg.local_dealloc %23 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc11) + tt.return loc(#loc58) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":242:30) +#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":243:27) +#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":244:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":245:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":246:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":247:38) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":249:35) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":251:47) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:24) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":274:24) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":252:30) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":253:33) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":254:39) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":254:52) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":255:41) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":255:31) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":256:27) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":256:48) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":258:26) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":259:26) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":260:41) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":260:28) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":261:41) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":261:28) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":262:37) +#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":262:49) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":263:37) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":263:49) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:38) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:49) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:68) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:61) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:30) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:37) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:68) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:79) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:60) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:30) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:64) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":274:64) +#loc44 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":269:26) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":269:41) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:68) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":275:39) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:45) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:37) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:25) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:76) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:56) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:37) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:62) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:43) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":284:31) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":285:25) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":251:4) +#loc59 = loc(callsite(#loc3 at #loc4)) +#loc60 = loc(callsite(#loc5 at #loc4)) +#loc61 = loc(callsite(#loc3 at #loc6)) +#loc62 = loc(callsite(#loc5 at #loc6)) +#loc63 = loc(callsite(#loc3 at #loc7)) +#loc64 = loc(callsite(#loc5 at #loc7)) + diff --git a/orig.mlir b/orig.mlir new file mode 100644 index 000000000000..9c63eb7a9f7d --- /dev/null +++ b/orig.mlir @@ -0,0 +1,386 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %false = arith.constant false loc(#loc1) + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c132_i32 = arith.constant 132 : i32 loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c255_i32 = arith.constant 255 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc80) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc81) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc82) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc83) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc84) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc85) + %7 = arith.muli %2, %4 : i32 loc(#loc8) + %8 = arith.divsi %7, %c132_i32 : i32 loc(#loc9) + %9 = arith.remsi %7, %c132_i32 : i32 loc(#loc10) + %10 = arith.cmpi slt, %0, %9 : i32 loc(#loc11) + %11 = scf.if %10 -> (i32) { + %122 = arith.addi %8, %c1_i32 : i32 loc(#loc13) + scf.yield %122 : i32 loc(#loc13) + } else { + scf.yield %8 : i32 loc(#loc1) + } loc(#loc12) + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc14) + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc14) + %14 = arith.muli %4, %c8_i32 : i32 loc(#loc15) + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc16) + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc16) + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc17) + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc17) + %19 = arith.muli %6, %11 : i32 loc(#loc18) + %20 = arith.subi %6, %c1_i32 : i32 loc(#loc19) + %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc20) + %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc21) + %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc22) + %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc23) + %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc24) + %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc25) + %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, 
#smem, mutable> loc(#loc26) + %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc27) + %29 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc28) + %30 = arith.divsi %0, %14 : i32 loc(#loc29) + %31 = arith.muli %30, %c8_i32 : i32 loc(#loc30) + %32 = arith.subi %2, %31 : i32 loc(#loc31) + %33 = arith.minsi %32, %c8_i32 : i32 loc(#loc32) + %34 = arith.remsi %0, %33 : i32 loc(#loc33) + %35 = arith.addi %31, %34 : i32 loc(#loc34) + %36 = arith.remsi %0, %14 : i32 loc(#loc35) + %37 = arith.divsi %36, %33 : i32 loc(#loc36) + %38 = arith.muli %35, %c128_i32 : i32 loc(#loc37) + %39 = arith.muli %37, %c256_i32 : i32 loc(#loc38) + %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> loc(#loc46) + %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> loc(#loc48) + %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) + %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> 
!ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc28) + %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %69 = ttg.async_commit_group %68 loc(#loc26) + %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> loc(#loc50) + %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc28) + %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %77 = ttg.async_commit_group %76 loc(#loc27) + %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc28) + %79 = arith.cmpi ne, %20, %c0_i32 : i32 loc(#loc86) + %80 = arith.extui %79 : i1 to i32 loc(#loc51) + %81 = arith.cmpi eq, %80, %c0_i32 : i32 loc(#loc53) + %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %122 = arith.addi %0, %c132_i32 : i32 loc(#loc55) + %123 = arith.divsi %122, %14 : i32 loc(#loc29) + %124 = arith.muli %123, %c8_i32 : i32 loc(#loc30) + %125 = arith.subi %2, %124 : i32 loc(#loc31) + %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc32) + %127 = arith.remsi %122, %126 : i32 loc(#loc33) + %128 = arith.addi %124, %127 : i32 loc(#loc34) + %129 = arith.remsi %122, %14 : i32 loc(#loc35) + %130 = arith.divsi %129, %126 : i32 loc(#loc36) + %131 = arith.muli %128, %c128_i32 : i32 loc(#loc37) + %132 = arith.muli %130, %c256_i32 : i32 loc(#loc38) + %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + } else { + scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + } loc(#loc54) + %83 = arith.muli %80, %c64_i32 : i32 loc(#loc56) + %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) + %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc46) + %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) + %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc48) + %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %102 = arith.subi %arg5, %83 : i32 loc(#loc60) + %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) + %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc28) + %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %110 = ttg.async_commit_group %109 loc(#loc26) + %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> loc(#loc50) + %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc28) + %117 = ttg.async_copy_global_to_local 
%101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %118 = ttg.async_commit_group %117 loc(#loc27) + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> ( + i32, i32, i32, i32, + tensor<128x256xf32, #mma>, + tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, + tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, + i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %122 = arith.subi %19, %c2_i32 : i32 loc(#loc28) + %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc28) + %124 = arith.cmpi eq, %arg10, %20 : i32 loc(#loc52) + %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc61) + %126 = arith.select %124, %c0_i32, %125 : i32 loc(#loc51) + %127 = arith.cmpi eq, %126, %c0_i32 : i32 loc(#loc53) + %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %178 = arith.addi %arg11, %c132_i32 : i32 loc(#loc55) + %179 = arith.divsi %178, %14 : i32 loc(#loc29) + %180 = arith.muli %179, %c8_i32 : i32 loc(#loc30) + %181 = arith.subi %2, %180 : i32 loc(#loc31) + %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc32) + %183 = arith.remsi %178, %182 : i32 loc(#loc33) + %184 = arith.addi %180, %183 : i32 loc(#loc34) + %185 = arith.remsi %178, %14 : i32 loc(#loc35) + %186 = arith.divsi %185, %182 : i32 loc(#loc36) + %187 = arith.muli %184, %c128_i32 : i32 loc(#loc37) + %188 = arith.muli %186, %c256_i32 : i32 loc(#loc38) + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) + %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) + %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + } else { + scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + } loc(#loc54) + %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc28) + %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc28) + %131 = arith.select 
%130, %129, %c0_i32 : i32 loc(#loc28) + %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc26) + %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc62) + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) + %137 = arith.addi %arg18, %c1_i32 : i32 loc(#loc28) + %138 = arith.cmpi slt, %137, %c3_i32 : i32 loc(#loc28) + %139 = arith.select %138, %137, %c0_i32 : i32 loc(#loc28) + %140 = arith.muli %126, %c64_i32 : i32 loc(#loc56) + %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) + %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) + %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) + %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) + %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> loc(#loc46) + %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) + %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) + %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) + %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> loc(#loc22) + %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) + %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> loc(#loc48) + %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %159 = arith.subi %arg5, %140 : i32 loc(#loc60) + %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> loc(#loc49) + %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) + %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, 
#shared, #smem, mutable, 3x128x64> loc(#loc26) + %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) + %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> loc(#loc28) + %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) + %167 = ttg.async_commit_group %166 loc(#loc26) + %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> loc(#loc50) + %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) + %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) + %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> loc(#loc28) + %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) + %175 = ttg.async_commit_group %174 loc(#loc27) + %176 = arith.cmpi eq, %arg22, %20 : i32 loc(#loc63) + %177 = arith.cmpi ne, %arg22, %20 : i32 loc(#loc87) + scf.if %176 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) + %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc65) + %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) + %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) + %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc67) + %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) + %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) + %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc69) + %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc70) + %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc70) + %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc71) + %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc71) + %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc72) + %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc73) + %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc73) + %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc73) + %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc74) + %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc74) + %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc75) + %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc75) + %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) + %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, 
#blocked2> loc(#loc76) + %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc76) + %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc77) + %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc78) + tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc78) + } loc(#loc64) + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc28) + } loc(#loc28) + %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc28) + %121 = ttg.async_wait {num = 0 : i32} loc(#loc28) + ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc28) + ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc28) + tt.return loc(#loc79) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) +#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) +#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":172:32) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:31) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:19) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:7) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:24) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:35) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":181:38) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:27) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:32) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:38) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:45) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:26) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:75) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:49) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:49) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:20) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:22) +#loc29 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:34) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:43) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:56) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:45) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:35) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:31) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:52) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":200:30) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:32) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:41) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:53) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:34) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:57) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:56) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:60) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:44) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:28) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:17) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:11) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:23) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:22) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:37) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:64) +#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:33) +#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:64) +#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:49) +#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:35) +#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:17) +#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:11) +#loc65 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:30) +#loc66 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:45) +#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) +#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) +#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:49) +#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:41) +#loc71 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:29) +#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:80) +#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:60) +#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) +#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:66) +#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:47) +#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":224:35) +#loc78 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:29) +#loc79 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:4) +#loc80 = loc(callsite(#loc3 at #loc4)) +#loc81 = loc(callsite(#loc5 at #loc4)) +#loc82 = loc(callsite(#loc3 at #loc6)) +#loc83 = loc(callsite(#loc5 at #loc6)) +#loc84 = loc(callsite(#loc3 at #loc7)) +#loc85 = loc(callsite(#loc5 at #loc7)) +#loc86 = loc(fused[#loc51, #loc52]) +#loc87 = loc(fused[#loc64, #loc63]) + diff --git a/orig2.mlir b/orig2.mlir new file mode 100644 index 000000000000..63cc3d385e0b --- /dev/null +++ b/orig2.mlir @@ -0,0 +1,293 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %false = arith.constant false + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, 
%c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.divsi %7, %c132_i32 : i32 + %9 = arith.remsi %7, %c132_i32 : i32 + %10 = arith.cmpi slt, %0, %9 : i32 + %11 = scf.if %10 -> (i32) { + %122 = arith.addi %8, %c1_i32 : i32 + scf.yield %122 : i32 + } else { + scf.yield %8 : i32 + } + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = arith.muli %4, %c8_i32 : i32 + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %19 = arith.muli %6, %11 : i32 + %20 = arith.subi %6, %c1_i32 : i32 + %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %29 = arith.cmpi sgt, %19, %c0_i32 : i32 + %30 = arith.divsi %0, %14 : i32 + %31 = arith.muli %30, %c8_i32 : i32 + %32 = arith.subi %2, %31 : i32 + %33 = arith.minsi %32, %c8_i32 : i32 + %34 = arith.remsi %0, %33 : i32 + %35 = arith.addi %31, %34 : i32 + %36 = arith.remsi %0, %14 : i32 + %37 = arith.divsi %36, %33 : i32 + %38 = arith.muli %35, %c128_i32 : i32 + %39 = arith.muli %37, %c256_i32 : i32 + %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %50 = tt.expand_dims 
%46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> + %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> + %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> + %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> + %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> + %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> + %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> + %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %69 = ttg.async_commit_group %68 + %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> + %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> + %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> + %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> + %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %77 = ttg.async_commit_group %76 + %78 = arith.cmpi sgt, %19, %c1_i32 : i32 + %79 = arith.cmpi ne, %20, %c0_i32 : i32 + %80 = arith.extui %79 : i1 to i32 + %81 = arith.cmpi eq, %80, %c0_i32 : i32 + %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %122 = arith.addi %0, %c132_i32 : i32 + %123 = arith.divsi %122, %14 : i32 + %124 = arith.muli %123, %c8_i32 : i32 + %125 = arith.subi %2, %124 : i32 + %126 = arith.minsi %125, %c8_i32 : i32 + %127 = arith.remsi %122, %126 : i32 + %128 = arith.addi %124, %127 : i32 + %129 = arith.remsi %122, %14 : i32 + %130 = arith.divsi %129, %126 : i32 + %131 = arith.muli %128, %c128_i32 : i32 + %132 = arith.muli %130, %c256_i32 : i32 + %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, 
parent = #blocked1}>> + %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } else { + scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %83 = arith.muli %80, %c64_i32 : i32 + %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> + %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> + %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> + %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %102 = arith.subi %arg5, %83 : i32 + %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> + %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> + %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> + %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> + %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %110 = ttg.async_commit_group %109 + %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> + %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> + %113 = 
tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> + %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> + %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %118 = ttg.async_commit_group %117 + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %122 = arith.subi %19, %c2_i32 : i32 + %123 = arith.cmpi slt, %arg9, %122 : i32 + %124 = arith.cmpi eq, %arg10, %20 : i32 + %125 = arith.addi %arg10, %c1_i32 : i32 + %126 = arith.select %124, %c0_i32, %125 : i32 + %127 = arith.cmpi eq, %126, %c0_i32 : i32 + %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %178 = arith.addi %arg11, %c132_i32 : i32 + %179 = arith.divsi %178, %14 : i32 + %180 = arith.muli %179, %c8_i32 : i32 + %181 = arith.subi %2, %180 : i32 + %182 = arith.minsi %181, %c8_i32 : i32 + %183 = arith.remsi %178, %182 : i32 + %184 = arith.addi %180, %183 : i32 + %185 = arith.remsi %178, %14 : i32 + %186 = arith.divsi %185, %182 : i32 + %187 = arith.muli %184, %c128_i32 : i32 + %188 = arith.muli %186, %c256_i32 : i32 + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } else { + scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %129 = arith.addi %arg19, %c1_i32 : i32 + %130 = arith.cmpi slt, %129, %c3_i32 : i32 + %131 = arith.select 
%130, %129, %c0_i32 : i32 + %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %133 = ttg.async_wait %arg20 {num = 2 : i32} + %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %137 = arith.addi %arg18, %c1_i32 : i32 + %138 = arith.cmpi slt, %137, %c3_i32 : i32 + %139 = arith.select %138, %137, %c0_i32 : i32 + %140 = arith.muli %126, %c64_i32 : i32 + %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> + %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> + %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> + %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> + %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %159 = arith.subi %arg5, %140 : i32 + %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> + %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> + %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> + %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> + %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %167 = ttg.async_commit_group %166 + %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> + %169 = arith.cmpi slt, %26, %168 : 
tensor<64x1xi32, #blocked> + %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> + %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> + %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %175 = ttg.async_commit_group %174 + %176 = arith.cmpi eq, %arg22, %20 : i32 + %177 = arith.cmpi ne, %arg22, %20 : i32 + scf.if %176 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %179 = arith.muli %arg24, %c128_i32 : i32 + %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %182 = arith.muli %arg26, %c256_i32 : i32 + %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> + %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> + %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> + %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> + %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> + } + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 + } + %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> + %121 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, 
mutable> + ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + tt.return + } +} + diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 94067cd6b0f2..5499e035915b 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -26,6 +26,7 @@ import triton.language as tl import triton.tools.experimental_descriptor import triton.profiler as proton +import pathlib from contextlib import contextmanager from typing import Optional @@ -151,6 +152,120 @@ def matmul(a, b): return c +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent_fused(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + ): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_persistent_fused(a, b): + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + 
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + + matmul_kernel_persistent_fused[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + @triton.jit(launch_metadata=_matmul_launch_metadata) def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # M, N, K, # @@ -209,6 +324,306 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # tl.store(c_ptrs, c, mask=c_mask) +matmul_kernel_persistent_ttgir = """ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %false = arith.constant false + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant 
dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.divsi %7, %c132_i32 : i32 + %9 = arith.remsi %7, %c132_i32 : i32 + %10 = arith.cmpi slt, %0, %9 : i32 + %11 = scf.if %10 -> (i32) { + %122 = arith.addi %8, %c1_i32 : i32 + scf.yield %122 : i32 + } else { + scf.yield %8 : i32 + } + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = arith.muli %4, %c8_i32 : i32 + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %19 = arith.muli %6, %11 : i32 + %20 = arith.subi %6, %c1_i32 : i32 + %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %29 = arith.cmpi sgt, %19, %c0_i32 : i32 + %30 = arith.divsi %0, %14 : i32 + %31 = arith.muli %30, %c8_i32 : i32 + %32 = arith.subi %2, %31 : i32 + %33 = arith.minsi %32, %c8_i32 : i32 + %34 = arith.remsi %0, %33 : i32 + %35 = arith.addi %31, %34 : i32 + %36 = arith.remsi %0, %14 : i32 + %37 = arith.divsi %36, %33 : i32 + %38 = arith.muli %35, %c128_i32 : i32 + %39 = arith.muli %37, %c256_i32 : i32 + %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, 
tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> + %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> + %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> + %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> + %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> + %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> + %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> + %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %69 = ttg.async_commit_group %68 + %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> + %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> + %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> + %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> + %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %77 = ttg.async_commit_group %76 + %78 = arith.cmpi sgt, %19, %c1_i32 : i32 + %79 = arith.cmpi ne, %20, %c0_i32 : i32 + %80 = arith.extui %79 : i1 to i32 + %81 = arith.cmpi eq, %80, %c0_i32 : i32 + %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %122 = arith.addi %0, %c132_i32 : i32 + %123 = arith.divsi %122, %14 : i32 + %124 = arith.muli %123, %c8_i32 : i32 + %125 = arith.subi %2, %124 : i32 + %126 = arith.minsi %125, %c8_i32 : i32 + %127 = arith.remsi %122, %126 : i32 + %128 = arith.addi %124, %127 : i32 + %129 = arith.remsi %122, %14 : i32 + %130 = arith.divsi %129, %126 : i32 + %131 = arith.muli %128, %c128_i32 : i32 + %132 = arith.muli %130, %c256_i32 : i32 + %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, 
parent = #blocked}>> + %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } else { + scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %83 = arith.muli %80, %c64_i32 : i32 + %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> + %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> + %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> + %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %102 = arith.subi %arg5, %83 : i32 + %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> + %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> + %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> + %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> + %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, 
mutable> + %110 = ttg.async_commit_group %109 + %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> + %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> + %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> + %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> + %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %118 = ttg.async_commit_group %117 + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %122 = arith.subi %19, %c2_i32 : i32 + %123 = arith.cmpi slt, %arg9, %122 : i32 + %124 = arith.cmpi eq, %arg10, %20 : i32 + %125 = arith.addi %arg10, %c1_i32 : i32 + %126 = arith.select %124, %c0_i32, %125 : i32 + %127 = arith.cmpi eq, %126, %c0_i32 : i32 + %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %178 = arith.addi %arg11, %c132_i32 : i32 + %179 = arith.divsi %178, %14 : i32 + %180 = arith.muli %179, %c8_i32 : i32 + %181 = arith.subi %2, %180 : i32 + %182 = arith.minsi %181, %c8_i32 : i32 + %183 = arith.remsi %178, %182 : i32 + %184 = arith.addi %180, %183 : i32 + %185 = arith.remsi %178, %14 : i32 + %186 = arith.divsi %185, %182 : i32 + %187 = arith.muli %184, %c128_i32 : i32 + %188 = arith.muli %186, %c256_i32 : i32 + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } else { + scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = 
#blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %129 = arith.addi %arg19, %c1_i32 : i32 + %130 = arith.cmpi slt, %129, %c3_i32 : i32 + %131 = arith.select %130, %129, %c0_i32 : i32 + %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %133 = ttg.async_wait %arg20 {num = 2 : i32} + %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %137 = arith.addi %arg18, %c1_i32 : i32 + %138 = arith.cmpi slt, %137, %c3_i32 : i32 + %139 = arith.select %138, %137, %c0_i32 : i32 + %140 = arith.muli %126, %c64_i32 : i32 + %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> + %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> + %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> + %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> + %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %159 = arith.subi %arg5, %140 : i32 + %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> + %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> + %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> + %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> + %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> 
-> <128x64xf16, #shared, #smem, mutable> + %167 = ttg.async_commit_group %166 + %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> + %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> + %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> + %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> + %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %175 = ttg.async_commit_group %174 + %176 = arith.cmpi eq, %arg22, %20 : i32 + %177 = arith.cmpi ne, %arg22, %20 : i32 + scf.if %176 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %179 = arith.muli %arg24, %c128_i32 : i32 + %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %182 = arith.muli %arg26, %c256_i32 : i32 + %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> + %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> + %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> + %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> + %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> + } + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 + } + %120 = 
ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> + %121 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + tt.return + } +} +""" + +file = pathlib.Path("matmul_kernel_persistent.ttgir") +file.write_text(matmul_kernel_persistent_ttgir) +matmul_kernel_persistent_precompiled = triton.compile(str(file)) + + def matmul_persistent(a, b): configs = { torch.float8_e4m3fn: { @@ -230,6 +645,24 @@ def matmul_persistent(a, b): c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + + #assert a.stride(1) == 1 and b.stride(0) == 1 and c.stride(1) == 1 + #bytes_per_elem = a.element_size() + #flops_str = f"flops{bytes_per_elem * 8}" + #with proton.scope(f"precompiled [M={M}, N={N}, K={K}]", + # {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): + # matmul_kernel_persistent_precompiled[(grid(configs[torch.float16])[0], 1, 1)]( + # a, + # b, + # c, # + # M, + # N, + # K, # + # a.stride(0), + # b.stride(1), # + # c.stride(0), + # ) + matmul_kernel_persistent[grid]( a, b, c, # M, N, K, # @@ -244,21 +677,23 @@ def matmul_persistent(a, b): num_stages=configs[dtype]["num_stages"], # num_warps=configs[dtype]["num_warps"], # ) - kernel = matmul_kernel_persistent.warmup( - a, b, c, # - M, N, K, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c.stride(0), c.stride(1), # - BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # - BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # - BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # - GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # - NUM_SMS=NUM_SMS, # - num_stages=configs[dtype]["num_stages"], # - num_warps=configs[dtype]["num_warps"], # - grid=grid - ) + + #kernel = matmul_kernel_persistent.warmup( + # a, b, c, # + # M, N, K, # + # a.stride(0), a.stride(1), # + # b.stride(0), b.stride(1), # + # c.stride(0), c.stride(1), # + # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + # BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + # BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + # GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + # NUM_SMS=NUM_SMS, # + # num_stages=configs[dtype]["num_stages"], # + # num_warps=configs[dtype]["num_warps"], # + # grid=grid + #) + #print(kernel.asm["ttgir"]) return c @@ -525,15 +960,16 @@ def bench(K, dtype, reps=1000, warmup_reps=10000): b = b.T.contiguous() - if cublas is not None: - bench_fn(reps, warmup_reps, cublas_matmul, a, b) - if dtype == torch.float16: - bench_fn(reps, warmup_reps, torch_matmul, a, b) - bench_fn(reps, warmup_reps, matmul, a, b.T) + #if cublas is not None: + # bench_fn(reps, warmup_reps, cublas_matmul, a, b) + #if dtype == torch.float16: + # bench_fn(reps, warmup_reps, torch_matmul, a, b) + #bench_fn(reps, warmup_reps, matmul, a, b.T) + bench_fn(reps, warmup_reps, matmul_persistent_fused, a, b.T) bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) - if supports_tma(): - bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) - bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b) + #if supports_tma(): + # bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) + # bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b) def validate(M, N, K, dtype): diff --git a/test.mlir b/test.mlir new file mode 100644 index 
000000000000..be0afa4cddd7 --- /dev/null +++ b/test.mlir @@ -0,0 +1,178 @@ +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) +module { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32> loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %c255_i32 = arith.constant 255 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c132_i32 = arith.constant 132 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<256xi32> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<128xi32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc63) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc64) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc65) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc66) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc67) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc68) + %7 = arith.muli %2, %4 : i32 loc(#loc8) + %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc10) + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc11) + %11 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc12) + %12 = tt.splat %arg3 : i32 -> tensor<128xi32> loc(#loc13) + %13 = tt.splat %arg4 : i32 -> tensor<256xi32> loc(#loc14) + %14 = tt.splat %arg6 : i32 -> tensor<128x1xi32> loc(#loc15) + %15 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> loc(#loc16) + %16 = tt.splat %arg7 : i32 -> tensor<1x256xi32> loc(#loc17) + %17 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> loc(#loc18) + %18 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc19) + %19 = tt.expand_dims %9 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc20) + %20 = tt.splat %arg8 : i32 -> tensor<128x1xi32> loc(#loc21) + %21 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc22) + %22 = tt.splat %arg3 : i32 -> 
tensor<128x1xi32> loc(#loc23) + %23 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc24) + scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { + %24 = arith.divsi %arg9, %8 : i32 loc(#loc26) + %25 = arith.muli %24, %c8_i32 : i32 loc(#loc27) + %26 = arith.subi %2, %25 : i32 loc(#loc28) + %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc29) + %28 = arith.remsi %arg9, %27 : i32 loc(#loc30) + %29 = arith.addi %25, %28 : i32 loc(#loc31) + %30 = arith.remsi %arg9, %8 : i32 loc(#loc32) + %31 = arith.divsi %30, %27 : i32 loc(#loc33) + %32 = arith.muli %29, %c128_i32 : i32 loc(#loc34) + %33 = arith.muli %31, %c256_i32 : i32 loc(#loc35) + %34 = tt.splat %32 : i32 -> tensor<128xi32> loc(#loc36) + %35 = arith.addi %34, %10 : tensor<128xi32> loc(#loc36) + %36 = tt.splat %33 : i32 -> tensor<256xi32> loc(#loc37) + %37 = arith.addi %36, %11 : tensor<256xi32> loc(#loc37) + %38 = arith.cmpi slt, %35, %12 : tensor<128xi32> loc(#loc13) + %39 = arith.select %38, %35, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> loc(#loc38) + %40 = arith.cmpi slt, %37, %13 : tensor<256xi32> loc(#loc14) + %41 = arith.select %40, %37, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> loc(#loc39) + %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc40) + %43 = arith.muli %42, %14 : tensor<128x1xi32> loc(#loc15) + %44 = tt.broadcast %43 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc41) + %45 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc42) + %46 = arith.muli %45, %16 : tensor<1x256xi32> loc(#loc17) + %47 = tt.broadcast %46 : tensor<1x256xi32> -> tensor<64x256xi32> loc(#loc43) + %48 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst) -> (tensor<128x256xf32>) : i32 { + %62 = arith.muli %arg10, %c64_i32 : i32 loc(#loc45) + %63 = tt.splat %62 : i32 -> tensor<64xi32> loc(#loc46) + %64 = arith.addi %63, %9 : tensor<64xi32> loc(#loc46) + %65 = tt.expand_dims %64 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc47) + %66 = tt.broadcast %65 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc41) + %67 = arith.addi %44, %66 : tensor<128x64xi32> loc(#loc41) + %68 = tt.addptr %15, %67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc16) + %69 = tt.expand_dims %64 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc48) + %70 = tt.broadcast %69 : tensor<64x1xi32> -> tensor<64x256xi32> loc(#loc43) + %71 = arith.addi %70, %47 : tensor<64x256xi32> loc(#loc43) + %72 = tt.addptr %17, %71 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> loc(#loc18) + %73 = arith.subi %arg5, %62 : i32 loc(#loc49) + %74 = tt.splat %73 : i32 -> tensor<1x64xi32> loc(#loc50) + %75 = arith.cmpi slt, %18, %74 : tensor<1x64xi32> loc(#loc50) + %76 = tt.broadcast %75 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc51) + %77 = tt.load %68, %76, %cst_1 : tensor<128x64x!tt.ptr> loc(#loc51) + %78 = tt.splat %73 : i32 -> tensor<64x1xi32> loc(#loc52) + %79 = arith.cmpi slt, %19, %78 : tensor<64x1xi32> loc(#loc52) + %80 = tt.broadcast %79 : tensor<64x1xi1> -> tensor<64x256xi1> loc(#loc53) + %81 = tt.load %72, %80, %cst_0 : tensor<64x256x!tt.ptr> loc(#loc53) + %82 = tt.dot %77, %81, %arg11, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> loc(#loc54) + scf.yield %82 : tensor<128x256xf32> loc(#loc55) + } loc(#loc44) + %49 = tt.expand_dims %35 {axis = 1 : i32} : 
tensor<128xi32> -> tensor<128x1xi32> loc(#loc56) + %50 = arith.muli %20, %49 : tensor<128x1xi32> loc(#loc21) + %51 = tt.addptr %21, %50 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc22) + %52 = tt.expand_dims %37 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc57) + %53 = tt.broadcast %51 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> loc(#loc58) + %54 = tt.broadcast %52 : tensor<1x256xi32> -> tensor<128x256xi32> loc(#loc58) + %55 = tt.addptr %53, %54 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> loc(#loc58) + %56 = arith.cmpi slt, %49, %22 : tensor<128x1xi32> loc(#loc23) + %57 = arith.cmpi slt, %52, %23 : tensor<1x256xi32> loc(#loc24) + %58 = tt.broadcast %56 : tensor<128x1xi1> -> tensor<128x256xi1> loc(#loc59) + %59 = tt.broadcast %57 : tensor<1x256xi1> -> tensor<128x256xi1> loc(#loc59) + %60 = arith.andi %58, %59 : tensor<128x256xi1> loc(#loc59) + %61 = arith.truncf %48 : tensor<128x256xf32> to tensor<128x256xf16> loc(#loc60) + tt.store %55, %61, %60 : tensor<128x256x!tt.ptr> loc(#loc61) + } loc(#loc25) + tt.return loc(#loc62) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) +#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) +#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:38) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:35) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:41) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:41) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:37) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:37) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:49) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:30) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:79) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:30) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:53) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:53) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:37) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:25) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:37) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:62) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:47) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":176:30) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":177:33) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:39) +#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:52) +#loc30 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:41) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:31) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:27) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:48) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:26) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":183:26) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:28) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:28) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:49) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:49) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:38) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:61) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:68) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:60) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:24) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:26) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:41) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:68) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:68) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:64) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:24) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:64) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:24) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:39) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:12) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:45) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:76) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:56) +#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:43) +#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:31) +#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:25) +#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:4) +#loc63 = loc(callsite(#loc3 at #loc4)) +#loc64 = loc(callsite(#loc5 at #loc4)) +#loc65 = loc(callsite(#loc3 at #loc6)) +#loc66 = loc(callsite(#loc5 at #loc6)) +#loc67 = loc(callsite(#loc3 at #loc7)) +#loc68 = loc(callsite(#loc5 at #loc7)) + diff --git a/test2.mlir b/test2.mlir new file mode 100644 index 000000000000..425c15288bee --- /dev/null +++ b/test2.mlir @@ -0,0 +1,128 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = 
[16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.muli %4, %c8_i32 : i32 + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %15 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %17 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %18 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %19 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %21 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %23 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %24 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %25 = tt.splat %arg3 : i32 -> 
tensor<128x1xi32, #blocked2> + %26 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { + %27 = arith.divsi %arg9, %8 : i32 + %28 = arith.muli %27, %c8_i32 : i32 + %29 = arith.subi %2, %28 : i32 + %30 = arith.minsi %29, %c8_i32 : i32 + %31 = arith.remsi %arg9, %30 : i32 + %32 = arith.addi %28, %31 : i32 + %33 = arith.remsi %arg9, %8 : i32 + %34 = arith.divsi %33, %30 : i32 + %35 = arith.muli %32, %c128_i32 : i32 + %36 = arith.muli %34, %c256_i32 : i32 + %37 = tt.splat %35 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %38 = tt.splat %35 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %39 = arith.addi %37, %11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %40 = arith.addi %38, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %41 = tt.splat %36 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %42 = tt.splat %36 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %43 = arith.addi %41, %13 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %44 = arith.addi %42, %14 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %45 = arith.cmpi slt, %39, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %46 = arith.select %45, %39, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %47 = arith.cmpi slt, %43, %16 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %48 = arith.select %47, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %49 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %50 = arith.muli %49, %17 : tensor<128x1xi32, #blocked1> + %51 = tt.broadcast %50 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %52 = tt.expand_dims %48 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %53 = arith.muli %52, %19 : tensor<1x256xi32, #blocked> + %54 = tt.broadcast %53 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %55 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst_3) -> (tensor<128x256xf32, #mma>) : i32 { + %70 = arith.muli %arg10, %c64_i32 : i32 + %71 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %72 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %73 = arith.addi %71, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %74 = arith.addi %72, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %75 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %76 = tt.broadcast %75 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %77 = arith.addi %51, %76 : tensor<128x64xi32, #blocked1> + %78 = tt.addptr %18, %77 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = tt.expand_dims %74 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %80 = tt.broadcast %79 : tensor<64x1xi32, #blocked> -> 
tensor<64x256xi32, #blocked> + %81 = arith.addi %80, %54 : tensor<64x256xi32, #blocked> + %82 = tt.addptr %20, %81 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %83 = arith.subi %arg5, %70 : i32 + %84 = tt.splat %83 : i32 -> tensor<1x64xi32, #blocked1> + %85 = arith.cmpi slt, %21, %84 : tensor<1x64xi32, #blocked1> + %86 = tt.broadcast %85 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %87 = tt.load %78, %86, %cst_1 : tensor<128x64x!tt.ptr, #blocked1> + %88 = ttg.local_alloc %87 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %89 = tt.splat %83 : i32 -> tensor<64x1xi32, #blocked> + %90 = arith.cmpi slt, %22, %89 : tensor<64x1xi32, #blocked> + %91 = tt.broadcast %90 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %92 = tt.load %82, %91, %cst_2 : tensor<64x256x!tt.ptr, #blocked> + %93 = ttg.local_alloc %92 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %94 = ttng.warp_group_dot %88, %93, %arg11 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + scf.yield %94 : tensor<128x256xf32, #mma> + } + %56 = tt.expand_dims %40 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %57 = arith.muli %23, %56 : tensor<128x1xi32, #blocked2> + %58 = tt.addptr %24, %57 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %59 = tt.expand_dims %44 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %60 = tt.broadcast %58 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %61 = tt.broadcast %59 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %62 = tt.addptr %60, %61 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %63 = arith.cmpi slt, %56, %25 : tensor<128x1xi32, #blocked2> + %64 = arith.cmpi slt, %59, %26 : tensor<1x256xi32, #blocked2> + %65 = tt.broadcast %63 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %66 = tt.broadcast %64 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %67 = arith.andi %65, %66 : tensor<128x256xi1, #blocked2> + %68 = arith.truncf %55 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %69 = ttg.convert_layout %68 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %62, %69, %67 : tensor<128x256x!tt.ptr, #blocked2> + } + tt.return + } +} + diff --git a/test3.mlir b/test3.mlir new file mode 100644 index 000000000000..a81b31ca4aa2 --- /dev/null +++ b/test3.mlir @@ -0,0 +1,177 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public 
@matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.muli %4, %c8_i32 : i32 + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %15 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %17 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %18 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %19 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %21 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %23 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %24 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %25 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %26 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %27 = arith.subi %7, %0 : i32 + %28 = arith.ceildivsi %27, %c132_i32 : i32 + %29 = arith.subi %6, %c0_i32 : i32 + %30 = arith.ceildivsi %29, %c1_i32 : i32 + %c0_i64 = arith.constant 0 : i64 + %31 = arith.extsi %30 : i32 to i64 + %c1_i64 = arith.constant 1 : i64 + %32 = arith.maxsi %c1_i64, %31 : i64 + %33 = arith.addi 
%c0_i64, %32 : i64 + %c0_i64_4 = arith.constant 0 : i64 + %34 = arith.subi %33, %c0_i64_4 : i64 + %35 = arith.extsi %28 : i32 to i64 + %36 = arith.muli %35, %34 : i64 + %c-1_i64 = arith.constant -1 : i64 + %37 = arith.subi %0, %c132_i32 : i32 + %38 = ub.poison : i32 + %39 = ub.poison : tensor<128x256xf32, #mma> + %40 = ub.poison : i32 + %41 = ub.poison : i32 + %c0_i64_5 = arith.constant 0 : i64 + %c1_i64_6 = arith.constant 1 : i64 + %42:6 = scf.for %arg9 = %c0_i64_5 to %36 step %c1_i64_6 iter_args(%arg10 = %c-1_i64, %arg11 = %37, %arg12 = %38, %arg13 = %39, %arg14 = %40, %arg15 = %41) -> (i64, i32, i32, tensor<128x256xf32, #mma>, i32, i32) : i64 { + %c1_i64_7 = arith.constant 1 : i64 + %43 = arith.addi %arg10, %c1_i64_7 : i64 + %44 = arith.remsi %43, %34 : i64 + %c0_i64_8 = arith.constant 0 : i64 + %45 = arith.subi %c0_i64, %c0_i64_8 : i64 + %46 = arith.cmpi eq, %44, %45 : i64 + %47:5 = scf.if %46 -> (i32, i32, i32, tensor<128x256xf32, #mma>, i32) { + %56 = arith.addi %arg11, %c132_i32 : i32 + %57 = arith.divsi %56, %8 : i32 + %58 = arith.muli %57, %c8_i32 : i32 + %59 = arith.subi %2, %58 : i32 + %60 = arith.minsi %59, %c8_i32 : i32 + %61 = arith.remsi %56, %60 : i32 + %62 = arith.addi %58, %61 : i32 + %63 = arith.remsi %56, %8 : i32 + %64 = arith.divsi %63, %60 : i32 + %65 = arith.muli %62, %c128_i32 : i32 + %66 = arith.muli %64, %c256_i32 : i32 + scf.yield %c0_i32, %65, %66, %cst_3, %56 : i32, i32, i32, tensor<128x256xf32, #mma>, i32 + } else { + scf.yield %arg12, %arg14, %arg15, %arg13, %arg11 : i32, i32, i32, tensor<128x256xf32, #mma>, i32 + } + %48 = arith.extsi %30 : i32 to i64 + %49 = arith.addi %45, %48 : i64 + %50 = arith.cmpi sge, %44, %45 : i64 + %51 = arith.cmpi slt, %44, %49 : i64 + %52 = arith.andi %50, %51 : i1 + %true = arith.constant true + %53:2 = scf.if %true -> (i32, tensor<128x256xf32, #mma>) { + %56 = tt.splat %47#1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %57 = arith.addi %56, %11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %58 = tt.splat %47#2 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %59 = arith.addi %58, %13 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %60 = arith.cmpi slt, %57, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %61 = arith.select %60, %57, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %62 = arith.cmpi slt, %59, %16 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %63 = arith.select %62, %59, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %64 = tt.expand_dims %61 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %65 = arith.muli %64, %17 : tensor<128x1xi32, #blocked1> + %66 = tt.broadcast %65 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %67 = tt.expand_dims %63 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %68 = arith.muli %67, %19 : tensor<1x256xi32, #blocked> + %69 = tt.broadcast %68 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %70 = arith.muli %47#0, %c64_i32 : i32 + %71 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 
0, parent = #blocked1}>> + %72 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %73 = arith.addi %71, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %74 = arith.addi %72, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %75 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %76 = tt.broadcast %75 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %77 = arith.addi %66, %76 : tensor<128x64xi32, #blocked1> + %78 = tt.addptr %18, %77 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = tt.expand_dims %74 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %80 = tt.broadcast %79 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %81 = arith.addi %80, %69 : tensor<64x256xi32, #blocked> + %82 = tt.addptr %20, %81 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %83 = arith.subi %arg5, %70 : i32 + %84 = tt.splat %83 : i32 -> tensor<1x64xi32, #blocked1> + %85 = arith.cmpi slt, %21, %84 : tensor<1x64xi32, #blocked1> + %86 = tt.broadcast %85 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %87 = tt.load %78, %86, %cst_1 : tensor<128x64x!tt.ptr, #blocked1> + %88 = ttg.local_alloc %87 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %89 = tt.splat %83 : i32 -> tensor<64x1xi32, #blocked> + %90 = arith.cmpi slt, %22, %89 : tensor<64x1xi32, #blocked> + %91 = tt.broadcast %90 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %92 = tt.load %82, %91, %cst_2 : tensor<64x256x!tt.ptr, #blocked> + %93 = ttg.local_alloc %92 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %94 = ttng.warp_group_dot %88, %93, %47#3 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + %95 = arith.addi %47#0, %c1_i32 : i32 + scf.yield %95, %94 : i32, tensor<128x256xf32, #mma> + } else { + scf.yield %47#0, %arg13 : i32, tensor<128x256xf32, #mma> + } + %c1_i64_9 = arith.constant 1 : i64 + %54 = arith.subi %34, %c1_i64_9 : i64 + %55 = arith.cmpi eq, %44, %54 : i64 + scf.if %55 { + %56 = tt.splat %47#1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %57 = arith.addi %56, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %58 = tt.splat %47#2 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %59 = arith.addi %58, %14 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %60 = tt.expand_dims %57 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %61 = arith.muli %23, %60 : tensor<128x1xi32, #blocked2> + %62 = tt.addptr %24, %61 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %63 = tt.expand_dims %59 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %64 = tt.broadcast %62 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %65 = tt.broadcast %63 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %66 = tt.addptr %64, %65 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %67 = arith.cmpi slt, %60, %25 : tensor<128x1xi32, #blocked2> + %68 = arith.cmpi slt, %63, %26 : tensor<1x256xi32, #blocked2> + %69 = tt.broadcast %67 : tensor<128x1xi1, 
#blocked2> -> tensor<128x256xi1, #blocked2> + %70 = tt.broadcast %68 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %71 = arith.andi %69, %70 : tensor<128x256xi1, #blocked2> + %72 = arith.truncf %53#1 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %73 = ttg.convert_layout %72 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %66, %73, %71 : tensor<128x256x!tt.ptr, #blocked2> + } else { + } + scf.yield %44, %47#4, %53#0, %53#1, %47#1, %47#2 : i64, i32, i32, tensor<128x256xf32, #mma>, i32, i32 + } + tt.return + } +} + diff --git a/test4.mlir b/test4.mlir new file mode 100644 index 000000000000..01d3e533847d --- /dev/null +++ b/test4.mlir @@ -0,0 +1,192 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked> + %0 = ub.poison : tensor<64x256xi32, #blocked1> + %1 = ub.poison : tensor<128x64xi32, #blocked2> + %2 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = ub.poison : tensor<128x256xf32, #mma> + %5 = ub.poison : i32 + %c-1_i64 = arith.constant -1 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_0 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %cst_1 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked2> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %true = arith.constant true + %false = arith.constant false + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %6 = tt.get_program_id x : i32 + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + 
%11 = arith.addi %arg5, %c63_i32 : i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %21 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %23 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked2> + %24 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> + %25 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked1> + %26 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %27 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %28 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %29 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> + %30 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %31 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked> + %32 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked> + %33 = arith.cmpi eq, %12, %c0_i32 : i32 + scf.if %33 { + scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { + %34 = arith.divsi %arg9, %14 : i32 + %35 = arith.muli %34, %c8_i32 : i32 + %36 = arith.subi %8, %35 : i32 + %37 = arith.minsi %36, %c8_i32 : i32 + %38 = arith.remsi %arg9, %37 : i32 + %39 = arith.addi %35, %38 : i32 + %40 = arith.remsi %arg9, %14 : i32 + %41 = arith.divsi %40, %37 : i32 + %42 = arith.muli %39, %c128_i32 : i32 + %43 = arith.muli %41, %c256_i32 : i32 + %44 = tt.splat %42 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %45 = arith.addi %44, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %46 = tt.splat %43 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %47 = arith.addi %46, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %48 = tt.expand_dims %45 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %49 = arith.muli %29, %48 : tensor<128x1xi32, #blocked> + %50 = tt.addptr %30, %49 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %51 = tt.expand_dims %47 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %52 = tt.broadcast %50 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> + %53 = tt.broadcast %51 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %54 = tt.addptr %52, %53 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %55 = arith.cmpi slt, %48, %31 : tensor<128x1xi32, #blocked> + %56 = arith.cmpi slt, %51, %32 : tensor<1x256xi32, #blocked> + %57 = tt.broadcast %55 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> + 
%58 = tt.broadcast %56 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> + %59 = arith.andi %57, %58 : tensor<128x256xi1, #blocked> + tt.store %54, %cst, %59 : tensor<128x256x!tt.ptr, #blocked> + } + } else { + %34 = arith.subi %13, %6 : i32 + %35 = arith.ceildivsi %34, %c132_i32 : i32 + %36 = arith.extsi %12 : i32 to i64 + %37 = arith.maxsi %36, %c1_i64 : i64 + %38 = arith.extsi %35 : i32 to i64 + %39 = arith.muli %38, %37 : i64 + %40 = arith.subi %6, %c132_i32 : i32 + %41:9 = scf.for %arg9 = %c0_i64 to %39 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %40, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i1) : i64 { + %42 = arith.addi %arg10, %c1_i64 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 + %43 = arith.remsi %42, %37 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 + %44 = arith.cmpi eq, %43, %c0_i64 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 + %45:7 = scf.if %44 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1) { + %74 = arith.addi %arg11, %c132_i32 : i32 + %75 = arith.divsi %74, %14 : i32 + %76 = arith.muli %75, %c8_i32 : i32 + %77 = arith.subi %8, %76 : i32 + %78 = arith.minsi %77, %c8_i32 : i32 + %79 = arith.remsi %74, %78 : i32 + %80 = arith.addi %76, %79 : i32 + %81 = arith.remsi %74, %14 : i32 + %82 = arith.divsi %81, %78 : i32 + %83 = arith.muli %80, %c128_i32 : i32 + %84 = arith.muli %82, %c256_i32 : i32 + %85 = tt.splat %83 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %86 = tt.splat %83 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %87 = arith.addi %85, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %88 = arith.addi %86, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %89 = tt.splat %84 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %90 = tt.splat %84 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %91 = arith.addi %89, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %92 = arith.addi %90, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %93 = arith.cmpi slt, %87, %21 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %94 = arith.select %93, %87, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %95 = arith.cmpi slt, %91, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %96 = arith.select %95, %91, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %97 = tt.expand_dims %94 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %98 = arith.muli %97, %23 : tensor<128x1xi32, #blocked2> + %99 = tt.broadcast %98 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %100 = tt.expand_dims %96 {axis = 0 : i32} : tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %101 = arith.muli %100, %25 : tensor<1x256xi32, #blocked1> + %102 = tt.broadcast %101 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + scf.yield %88, %92, %99, %102, %74, %c0_i32, %false : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17, %arg11, %arg12, %arg18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1 + } {loop.cluster = 1 : i32, loop.stage = 0 : i32} + %46 = arith.muli %45#5, %c64_i32 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 + %47 = tt.splat %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %48 = tt.splat %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %49 = arith.addi %47, %15 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %50 = arith.addi %48, %16 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %51 = tt.expand_dims %49 {axis = 0 : i32, loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %52 = tt.broadcast %51 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %53 = arith.addi %45#2, %52 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64xi32, #blocked2> + %54 = tt.addptr %24, %53 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %55 = tt.expand_dims %50 {axis = 1 : i32, loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %56 = tt.broadcast %55 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %57 = arith.addi %56, %45#3 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256xi32, #blocked1> + %58 = tt.addptr %26, %57 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %59 = arith.subi %arg5, %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 + %60 = tt.splat %59 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<1x64xi32, #blocked2> + %61 = arith.cmpi slt, %27, %60 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<1x64xi32, #blocked2> + %62 = tt.broadcast %61 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> + %63 = tt.load %54, %62, %cst_2 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr, #blocked2> + %64 = ttg.local_alloc %63 {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %65 = tt.splat %59 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x1xi32, #blocked1> + %66 = arith.cmpi slt, %28, %65 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked1> + %67 = tt.broadcast %66 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : 
tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %68 = tt.load %58, %67, %cst_3 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256x!tt.ptr, #blocked1> + %69 = ttg.local_alloc %68 {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %70 = ttng.warp_group_dot %64, %69, %arg13, %45#6 {inputPrecision = 0 : i32, loop.cluster = 2 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + %71 = arith.addi %45#5, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : i32 + %72 = arith.subi %37, %c1_i64 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i64 + %73 = arith.cmpi eq, %43, %72 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i64 + scf.if %73 { + %74 = tt.expand_dims %45#0 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %75 = arith.muli %29, %74 : tensor<128x1xi32, #blocked> + %76 = tt.addptr %30, %75 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %77 = tt.expand_dims %45#1 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %78 = tt.broadcast %76 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> + %79 = tt.broadcast %77 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %80 = tt.addptr %78, %79 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %81 = arith.cmpi slt, %74, %31 : tensor<128x1xi32, #blocked> + %82 = arith.cmpi slt, %77, %32 : tensor<1x256xi32, #blocked> + %83 = tt.broadcast %81 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> + %84 = tt.broadcast %82 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> + %85 = arith.andi %83, %84 : tensor<128x256xi1, #blocked> + %86 = arith.truncf %70 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %87 = ttg.convert_layout %86 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> + tt.store %80, %87, %85 : tensor<128x256x!tt.ptr, #blocked> + } {loop.cluster = 5 : i32, loop.stage = 2 : i32} + scf.yield %43, %45#4, %71, %70, %45#0, %45#1, %45#2, %45#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i1 + } + } + tt.return + } +} + diff --git a/test5.mlir b/test5.mlir new file mode 100644 index 000000000000..07d0108d2182 --- /dev/null +++ b/test5.mlir @@ -0,0 +1,345 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 
: i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i64 = arith.constant 2 : i64 + %c3_i32 = arith.constant 3 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked> + %0 = ub.poison : tensor<64x256xi32, #blocked1> + %1 = ub.poison : tensor<128x64xi32, #blocked2> + %2 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = ub.poison : tensor<128x256xf32, #mma> + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_0 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %cst_1 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked2> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %5 = tt.get_program_id x : i32 + %6 = arith.addi %arg3, %c127_i32 : i32 + %7 = arith.divsi %6, %c128_i32 : i32 + %8 = arith.addi %arg4, %c255_i32 : i32 + %9 = arith.divsi %8, %c256_i32 : i32 + %10 = arith.addi %arg5, %c63_i32 : i32 + %11 = arith.divsi %10, %c64_i32 : i32 + %12 = arith.muli %7, %9 : i32 + %13 = arith.muli %9, %c8_i32 : i32 + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %20 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %21 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked2> + %23 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> + %24 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked1> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %26 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %27 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %28 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> + %29 = 
tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %30 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked> + %31 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked> + %32 = arith.cmpi eq, %11, %c0_i32 : i32 + scf.if %32 { + scf.for %arg9 = %5 to %12 step %c132_i32 : i32 { + %33 = arith.divsi %arg9, %13 : i32 + %34 = arith.muli %33, %c8_i32 : i32 + %35 = arith.subi %7, %34 : i32 + %36 = arith.minsi %35, %c8_i32 : i32 + %37 = arith.remsi %arg9, %36 : i32 + %38 = arith.addi %34, %37 : i32 + %39 = arith.remsi %arg9, %13 : i32 + %40 = arith.divsi %39, %36 : i32 + %41 = arith.muli %38, %c128_i32 : i32 + %42 = arith.muli %40, %c256_i32 : i32 + %43 = tt.splat %41 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %44 = arith.addi %43, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %46 = arith.addi %45, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %47 = tt.expand_dims %44 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %48 = arith.muli %28, %47 : tensor<128x1xi32, #blocked> + %49 = tt.addptr %29, %48 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %50 = tt.expand_dims %46 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %51 = tt.broadcast %49 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> + %52 = tt.broadcast %50 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %53 = tt.addptr %51, %52 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %54 = arith.cmpi slt, %47, %30 : tensor<128x1xi32, #blocked> + %55 = arith.cmpi slt, %50, %31 : tensor<1x256xi32, #blocked> + %56 = tt.broadcast %54 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> + %57 = tt.broadcast %55 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> + %58 = arith.andi %56, %57 : tensor<128x256xi1, #blocked> + tt.store %53, %cst, %58 : tensor<128x256x!tt.ptr, #blocked> + } + } else { + %33 = arith.subi %12, %5 : i32 + %34 = arith.ceildivsi %33, %c132_i32 : i32 + %35 = arith.extsi %11 : i32 to i64 + %36 = arith.maxsi %35, %c1_i64 : i64 + %37 = arith.extsi %34 : i32 to i64 + %38 = arith.muli %37, %36 : i64 + %39 = arith.subi %5, %c132_i32 : i32 + %40 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %41 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %42 = arith.cmpi sgt, %38, %c0_i64 : i64 + %43 = arith.remsi %c0_i64, %36 : i64 + %44 = arith.cmpi eq, %43, %c0_i64 : i64 + %45:5 = scf.if %44 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { + %108 = arith.divsi %5, %13 : i32 + %109 = arith.muli %108, %c8_i32 : i32 + %110 = arith.subi %7, %109 : i32 + %111 = arith.minsi %110, %c8_i32 : i32 + %112 = arith.remsi %5, %111 : i32 + %113 = arith.addi %109, %112 : i32 + %114 = arith.remsi %5, %13 : i32 + %115 = arith.divsi %114, %111 : i32 + %116 = arith.muli %113, %c128_i32 : i32 + %117 = arith.muli %115, %c256_i32 : i32 + %118 = tt.splat %116 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %119 = tt.splat %116 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %120 = arith.addi %118, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = 
#blocked2}>> + %121 = arith.addi %119, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %122 = tt.splat %117 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %123 = tt.splat %117 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %124 = arith.addi %122, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %125 = arith.addi %123, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %126 = arith.cmpi slt, %120, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %127 = arith.select %126, %120, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %128 = arith.cmpi slt, %124, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %129 = arith.select %128, %124, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %130 = tt.expand_dims %127 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %131 = arith.muli %130, %22 : tensor<128x1xi32, #blocked2> + %132 = tt.broadcast %131 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %133 = tt.expand_dims %129 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %134 = arith.muli %133, %24 : tensor<1x256xi32, #blocked1> + %135 = tt.broadcast %134 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + scf.yield %121, %125, %132, %135, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } else { + scf.yield %3, %2, %1, %0, %39 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } + %46 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %47 = tt.broadcast %46 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %48 = arith.addi %45#2, %47 : tensor<128x64xi32, #blocked2> + %49 = tt.addptr %23, %48 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %50 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %51 = tt.broadcast %50 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %52 = arith.addi %51, %45#3 : tensor<64x256xi32, #blocked1> + %53 = tt.addptr %25, %52 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %54 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked2> + %55 = arith.cmpi slt, %26, %54 : tensor<1x64xi32, #blocked2> + %56 = tt.broadcast %55 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> + %57 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> + %58 = tt.splat %42 : i1 -> tensor<128x64xi1, #blocked2> + %59 = arith.andi %58, %56 : tensor<128x64xi1, #blocked2> + %60 = ttg.async_copy_global_to_local %49, %57 mask %59 other %cst_2 : 
tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> + %61 = ttg.async_commit_group %60 + %62 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked1> + %63 = arith.cmpi slt, %27, %62 : tensor<64x1xi32, #blocked1> + %64 = tt.broadcast %63 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %65 = ttg.memdesc_subview %41[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> + %66 = tt.splat %42 : i1 -> tensor<64x256xi1, #blocked1> + %67 = arith.andi %66, %64 : tensor<64x256xi1, #blocked1> + %68 = ttg.async_copy_global_to_local %53, %65 mask %67 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> + %69 = ttg.async_commit_group %68 + %70 = arith.cmpi sgt, %38, %c1_i64 : i64 + %71 = arith.addi %43, %c1_i64 : i64 + %72 = arith.remsi %71, %36 : i64 + %73 = arith.cmpi eq, %72, %c0_i64 : i64 + %74:5 = scf.if %73 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { + %108 = arith.addi %45#4, %c132_i32 : i32 + %109 = arith.divsi %108, %13 : i32 + %110 = arith.muli %109, %c8_i32 : i32 + %111 = arith.subi %7, %110 : i32 + %112 = arith.minsi %111, %c8_i32 : i32 + %113 = arith.remsi %108, %112 : i32 + %114 = arith.addi %110, %113 : i32 + %115 = arith.remsi %108, %13 : i32 + %116 = arith.divsi %115, %112 : i32 + %117 = arith.muli %114, %c128_i32 : i32 + %118 = arith.muli %116, %c256_i32 : i32 + %119 = tt.splat %117 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %120 = tt.splat %117 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %121 = arith.addi %119, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %122 = arith.addi %120, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %123 = tt.splat %118 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %124 = tt.splat %118 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %125 = arith.addi %123, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %126 = arith.addi %124, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %127 = arith.cmpi slt, %121, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %128 = arith.select %127, %121, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %129 = arith.cmpi slt, %125, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %130 = arith.select %129, %125, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %131 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %132 = arith.muli %131, %22 : tensor<128x1xi32, #blocked2> + %133 = tt.broadcast %132 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %134 = tt.expand_dims %130 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %135 = arith.muli %134, %24 : tensor<1x256xi32, #blocked1> + %136 = tt.broadcast %135 
: tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + scf.yield %122, %126, %133, %136, %108 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } else { + scf.yield %45#0, %45#1, %45#2, %45#3, %45#4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } + %75 = arith.select %73, %c0_i32, %c1_i32 : i32 + %76 = arith.muli %75, %c64_i32 : i32 + %77 = tt.splat %76 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %78 = tt.splat %76 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %79 = arith.addi %77, %14 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %80 = arith.addi %78, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %81 = tt.expand_dims %79 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %82 = tt.broadcast %81 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %83 = arith.addi %74#2, %82 : tensor<128x64xi32, #blocked2> + %84 = tt.addptr %23, %83 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %85 = tt.expand_dims %80 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %86 = tt.broadcast %85 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %87 = arith.addi %86, %74#3 : tensor<64x256xi32, #blocked1> + %88 = tt.addptr %25, %87 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %89 = arith.subi %arg5, %76 : i32 + %90 = tt.splat %89 : i32 -> tensor<1x64xi32, #blocked2> + %91 = arith.cmpi slt, %26, %90 : tensor<1x64xi32, #blocked2> + %92 = tt.broadcast %91 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> + %93 = ttg.memdesc_subview %40[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> + %94 = tt.splat %70 : i1 -> tensor<128x64xi1, #blocked2> + %95 = arith.andi %94, %92 : tensor<128x64xi1, #blocked2> + %96 = ttg.async_copy_global_to_local %84, %93 mask %95 other %cst_2 : tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> + %97 = ttg.async_commit_group %96 + %98 = tt.splat %89 : i32 -> tensor<64x1xi32, #blocked1> + %99 = arith.cmpi slt, %27, %98 : tensor<64x1xi32, #blocked1> + %100 = tt.broadcast %99 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %101 = ttg.memdesc_subview %41[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> + %102 = tt.splat %70 : i1 -> tensor<64x256xi1, #blocked1> + %103 = arith.andi %102, %100 : tensor<64x256xi1, #blocked1> + %104 = ttg.async_copy_global_to_local %88, %101 mask %103 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> + %105 = ttg.async_commit_group %104 + %106:23 = scf.for %arg9 = %c0_i64 to %38 step %c1_i64 iter_args(%arg10 = %72, %arg11 = %74#4, %arg12 = %c1_i32, %arg13 = %4, %arg14 = %74#0, %arg15 = %74#1, %arg16 = %74#2, %arg17 = %74#3, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %44, %arg21 = %73, %arg22 = %61, %arg23 = %97, %arg24 = %69, %arg25 = %105, %arg26 = %75, %arg27 = %43, %arg28 = %72, %arg29 = %45#0, %arg30 = 
%74#0, %arg31 = %45#1, %arg32 = %74#1) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i64, i64, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) : i64 { + %108 = arith.subi %38, %c2_i64 : i64 + %109 = arith.cmpi slt, %arg9, %108 : i64 + %110 = arith.addi %arg10, %c1_i64 : i64 + %111 = arith.remsi %110, %36 : i64 + %112 = arith.cmpi eq, %111, %c0_i64 : i64 + %113:5 = scf.if %112 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { + %160 = arith.addi %arg11, %c132_i32 : i32 + %161 = arith.divsi %160, %13 : i32 + %162 = arith.muli %161, %c8_i32 : i32 + %163 = arith.subi %7, %162 : i32 + %164 = arith.minsi %163, %c8_i32 : i32 + %165 = arith.remsi %160, %164 : i32 + %166 = arith.addi %162, %165 : i32 + %167 = arith.remsi %160, %13 : i32 + %168 = arith.divsi %167, %164 : i32 + %169 = arith.muli %166, %c128_i32 : i32 + %170 = arith.muli %168, %c256_i32 : i32 + %171 = tt.splat %169 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %172 = tt.splat %169 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %173 = arith.addi %171, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %174 = arith.addi %172, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %175 = tt.splat %170 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %176 = tt.splat %170 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %177 = arith.addi %175, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %178 = arith.addi %176, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %179 = arith.cmpi slt, %173, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %180 = arith.select %179, %173, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %181 = arith.cmpi slt, %177, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %182 = arith.select %181, %177, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %183 = tt.expand_dims %180 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %184 = arith.muli %183, %22 : tensor<128x1xi32, #blocked2> + %185 = tt.broadcast %184 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %186 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %187 = arith.muli %186, %24 : tensor<1x256xi32, #blocked1> + %188 = tt.broadcast %187 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + scf.yield %174, %178, %185, %188, %160 : tensor<128xi32, #ttg.slice<{dim = 1, parent = 
#blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } else { + scf.yield %arg14, %arg15, %arg16, %arg17, %arg11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 + } + %114 = arith.addi %arg19, %c1_i32 : i32 + %115 = arith.cmpi slt, %114, %c3_i32 : i32 + %116 = arith.select %115, %114, %c0_i32 : i32 + %117 = arith.select %arg20, %cst_4, %arg13 : tensor<128x256xf32, #mma> + %118 = ttg.memdesc_subview %40[%116, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> + %119 = ttg.async_wait %arg24 {num = 2 : i32} + %120 = ttg.memdesc_subview %41[%116, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> + %121 = ttng.warp_group_dot %118, %120, %117 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> + %122:3 = ttng.warp_group_dot_wait %121, %118, %120 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> + %123 = arith.addi %arg26, %c1_i32 : i32 + %124 = arith.addi %arg18, %c1_i32 : i32 + %125 = arith.cmpi slt, %124, %c3_i32 : i32 + %126 = arith.select %125, %124, %c0_i32 : i32 + %127 = arith.select %112, %c0_i32, %123 : i32 + %128 = arith.muli %127, %c64_i32 : i32 + %129 = tt.splat %128 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %130 = tt.splat %128 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %131 = arith.addi %129, %14 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %132 = arith.addi %130, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %133 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> + %134 = tt.broadcast %133 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> + %135 = arith.addi %113#2, %134 : tensor<128x64xi32, #blocked2> + %136 = tt.addptr %23, %135 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %137 = tt.expand_dims %132 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %138 = tt.broadcast %137 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %139 = arith.addi %138, %113#3 : tensor<64x256xi32, #blocked1> + %140 = tt.addptr %25, %139 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %141 = arith.subi %arg5, %128 : i32 + %142 = tt.splat %141 : i32 -> tensor<1x64xi32, #blocked2> + %143 = arith.cmpi slt, %26, %142 : tensor<1x64xi32, #blocked2> + %144 = tt.broadcast %143 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> + %145 = ttg.memdesc_subview %40[%126, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> + %146 = tt.splat %109 : i1 -> tensor<128x64xi1, #blocked2> + %147 = arith.andi %146, %144 : tensor<128x64xi1, #blocked2> + %148 = ttg.async_copy_global_to_local %136, %145 mask %147 other %cst_2 : tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, 
mutable, 3x128x64> + %149 = ttg.async_commit_group %148 + %150 = tt.splat %141 : i32 -> tensor<64x1xi32, #blocked1> + %151 = arith.cmpi slt, %27, %150 : tensor<64x1xi32, #blocked1> + %152 = tt.broadcast %151 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %153 = ttg.memdesc_subview %41[%126, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> + %154 = tt.splat %109 : i1 -> tensor<64x256xi1, #blocked1> + %155 = arith.andi %154, %152 : tensor<64x256xi1, #blocked1> + %156 = ttg.async_copy_global_to_local %140, %153 mask %155 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> + %157 = ttg.async_commit_group %156 + %158 = arith.subi %36, %c1_i64 : i64 + %159 = arith.cmpi eq, %arg27, %158 : i64 + scf.if %159 { + %160 = tt.expand_dims %arg29 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %161 = arith.muli %28, %160 : tensor<128x1xi32, #blocked> + %162 = tt.addptr %29, %161 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %163 = tt.expand_dims %arg31 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %164 = tt.broadcast %162 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> + %165 = tt.broadcast %163 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %166 = tt.addptr %164, %165 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %167 = arith.cmpi slt, %160, %30 : tensor<128x1xi32, #blocked> + %168 = arith.cmpi slt, %163, %31 : tensor<1x256xi32, #blocked> + %169 = tt.broadcast %167 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> + %170 = tt.broadcast %168 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> + %171 = arith.andi %169, %170 : tensor<128x256xi1, #blocked> + %172 = arith.truncf %122#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %173 = ttg.convert_layout %172 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> + tt.store %166, %173, %171 : tensor<128x256x!tt.ptr, #blocked> + } + scf.yield %111, %113#4, %123, %122#0, %113#0, %113#1, %113#2, %113#3, %126, %116, %arg21, %112, %arg23, %149, %arg25, %157, %127, %arg28, %111, %arg30, %113#0, %arg32, %113#1 : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i64, i64, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %107 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %40 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %41 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + } + tt.return + } +} + diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index c29c325cf8b9..cf36320b828c 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -242,8 +242,6 @@ def make_ttgir(mod, metadata, options): "num_stages == 0. 
Now it will not happen anymore; " "please update to use num_stages == 2 for " "equivalent behavior in the past.") - passes.ttgpuir.add_fuse_nested_loops(pm) - passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 2411e85dcaa9..d2ce5bdb4abd 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -216,7 +216,7 @@ def make_ttir(mod, metadata, opt): passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) - passes.common.add_licm(pm) + #passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) pm.run(mod) @@ -251,12 +251,13 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: - passes.ttgpuir.add_optimize_accumulator_init(pm) - passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) + passes.common.add_canonicalizer(pm) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.ttgpuir.add_coalesce_async_copy(pm) From 8d3f10834c5cd14f23fa241c5e54a13037952b63 Mon Sep 17 00:00:00 2001 From: Mogball Date: Sat, 25 Jan 2025 02:10:05 -0500 Subject: [PATCH 06/32] omg it works --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 41 +- new.mlir | 659 +++++++++--------- new2.mlir | 309 ++++++++ new3.mlir | 291 ++++++++ orig.mlir | 617 ++++++++-------- orig2.mlir | 224 +++--- python/tutorials/09-persistent-matmul.py | 458 ++++++------ third_party/nvidia/backend/compiler.py | 2 + 8 files changed, 1609 insertions(+), 992 deletions(-) create mode 100644 new2.mlir create mode 100644 new3.mlir diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index a0899815939c..3f5d8a9461a1 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -321,6 +321,19 @@ static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, return b.create<arith::ExtSIOp>(type, value); } +static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { + Type elTy = getElementTypeOrSelf(type); + if (!elTy.isIntOrIndexOrFloat() || + (!isa<RankedTensorType>(type) && type != elTy)) + return b.create<ub::PoisonOp>(type); + + TypedAttr attr = isa<FloatType>(elTy) ? 
TypedAttr(b.getFloatAttr(elTy, 0)) + : b.getIntegerAttr(elTy, 0); + if (auto tensor = dyn_cast<RankedTensorType>(type)) + attr = SplatElementsAttr::get(tensor, attr); + return b.create<arith::ConstantOp>(attr); +} + // Given a one level loop nest in the form // // for i in range(lbi, ubi, stepi): @@ -349,7 +362,7 @@ static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, // T = -1 // i = lbi - stepi // for _ in range(total_iters): -// T = (T + 1) % inner_len +// T = 0 if T == (inner_len - 1) else T + 1 // // if T == 0: // i += stepi @@ -511,7 +524,6 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { intTyWidth = std::max(intTyWidth, getIntTypeWidth(lenInner.getType())); lenInners.push_back(lenInner); } - intTyWidth = std::min(64u, intTyWidth * 2); auto intTy = b.getIntegerType(intTyWidth); auto intTyCst = [&](int64_t v) { @@ -576,17 +588,17 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { unsigned ivarStartIdx = fusedInits.size(); for (scf::ForOp loop : innerLoops) { fusedInits.push_back( - b.create<ub::PoisonOp>(loop.getInductionVar().getType())); + createPoisonOrZero(b, loop.getInductionVar().getType())); } unsigned innerOutsStartIdx = fusedInits.size(); for (scf::ForOp loop : innerLoops) { for (Type resultType : loop.getResultTypes()) - fusedInits.push_back(b.create<ub::PoisonOp>(resultType)); + fusedInits.push_back(createPoisonOrZero(b, resultType)); } unsigned logueOutsStartIdx = fusedInits.size(); for (Logue &logue : logues) { for (Type outputType : logue.getOutputTypes()) - fusedInits.push_back(b.create<ub::PoisonOp>(outputType)); + fusedInits.push_back(createPoisonOrZero(b, outputType)); } // for _ in range(total_iters): @@ -600,10 +612,13 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { } b.setInsertionPointToStart(fused.getBody()); - // T = (T + 1) % inner_len + // T = 0 if T == (inner_len - 1) else T + 1 Value T = fused.getRegionIterArg(0); - T = b.create<arith::AddIOp>(T, intTyCst(1)); - T = b.create<arith::RemSIOp>(T, innerLen); + Value nextT = b.create<arith::AddIOp>(T, intTyCst(1)); + Value rollover = + b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, T, + b.create<arith::SubIOp>(innerLen, intTyCst(1))); + T = b.create<arith::SelectOp>(rollover, intTyCst(0), nextT); // `i` is computed inside the first prologue. Value curI = fused.getRegionIterArg(1); @@ -918,11 +933,11 @@ static void sinkHeavyOps(scf::ForOp outerLoop, scf::ForOp innerLoop, auto inInnerLoop = [&](Operation *op) { return innerLoop.getBodyRegion().isAncestor(op->getParentRegion()); }; - //sinkHeavyOps(limit, innerLoop.getBody(), innerLoop.getBody()->begin(), - // {outerLoop.getBody()->begin(), innerLoop->getIterator()}, - // inInnerLoop, [&](size_t fanInSize, size_t fanOutSize) { - // return fanInSize * 4 <= fanOutSize; - // }); + // sinkHeavyOps(limit, innerLoop.getBody(), innerLoop.getBody()->begin(), + // {outerLoop.getBody()->begin(), innerLoop->getIterator()}, + // inInnerLoop, [&](size_t fanInSize, size_t fanOutSize) { + // return fanInSize * 4 <= fanOutSize; + // }); // Move computations in the prologue that can be done in the epilogue. This is // always beneficial.
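Note on the hunk above: the functionally interesting change is the update rule for the fused loop counter T. The remainder-based update `(T + 1) % inner_len` is replaced by a compare-and-select rollover, and the `ub.poison` initializers of the fused iter args are replaced by zero constants whenever `createPoisonOrZero` can build one. The sketch below is a standalone Python model, not the pass itself; the only assumptions taken from the pass's pseudocode comments are that T starts at -1 and that inner_len is clamped to at least 1. It checks that the old and new counter updates produce the same sequence, which is what makes the rewrite safe.

# Standalone sketch: the two update rules for the fused-loop counter T.
# Assumes T is initialized to -1 and inner_len >= 1, per the pass pseudocode.

def step_mod(T, inner_len):
    # Old update emitted by the pass: T = (T + 1) % inner_len
    return (T + 1) % inner_len

def step_select(T, inner_len):
    # New update: T = 0 if T == (inner_len - 1) else T + 1
    return 0 if T == inner_len - 1 else T + 1

for inner_len in range(1, 8):
    t_mod = t_sel = -1
    for _ in range(3 * inner_len):
        t_mod = step_mod(t_mod, inner_len)
        t_sel = step_select(t_sel, inner_len)
        assert t_mod == t_sel  # both stay in [0, inner_len)

Both forms keep T in [0, inner_len) once it leaves its initial -1; the select form avoids an integer remainder on the fused loop's hot path.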
diff --git a/new.mlir b/new.mlir index 24bf404cb01f..7852f8582e95 100644 --- a/new.mlir +++ b/new.mlir @@ -1,20 +1,16 @@ #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0) +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0) #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":231:0)) attributes {noinline = false} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0)) attributes {noinline = false} { %c2_i64 = arith.constant 2 : i64 loc(#loc1) %c3_i32 = arith.constant 3 : i32 loc(#loc1) %c-1_i32 = arith.constant -1 : i32 loc(#loc1) - %0 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - %1 = ub.poison : 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) - %2 = ub.poison : tensor<128x256xf32, #mma> loc(#loc1) - %3 = ub.poison : i32 loc(#loc1) %c1_i64 = arith.constant 1 : i64 loc(#loc1) %c0_i64 = arith.constant 0 : i64 loc(#loc1) %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) @@ -31,353 +27,348 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %c127_i32 = arith.constant 127 : i32 loc(#loc1) %c255_i32 = arith.constant 255 : i32 loc(#loc1) %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %4 = tt.get_program_id x : i32 loc(#loc2) - %5 = arith.addi %arg3, %c127_i32 : i32 loc(#loc59) - %6 = arith.divsi %5, %c128_i32 : i32 loc(#loc60) - %7 = arith.addi %arg4, %c255_i32 : i32 loc(#loc61) - %8 = arith.divsi %7, %c256_i32 : i32 loc(#loc62) - %9 = arith.addi %arg5, %c63_i32 : i32 loc(#loc63) - %10 = arith.divsi %9, %c64_i32 : i32 loc(#loc64) - %11 = arith.muli %6, %8 : i32 loc(#loc8) - %12 = arith.muli %8, %c8_i32 : i32 loc(#loc9) - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10) - %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc10) - %15 = arith.subi %11, %4 : i32 loc(#loc11) - %16 = arith.ceildivsi %15, %c132_i32 : i32 loc(#loc11) - %17 = arith.extsi %10 : i32 to i64 loc(#loc11) - %18 = arith.maxsi %17, %c1_i64 : i64 loc(#loc11) - %19 = arith.extsi %16 : i32 to i64 loc(#loc11) - %20 = arith.muli %19, %18 : i64 loc(#loc11) - %21 = arith.subi %4, %c132_i32 : i32 loc(#loc11) - %22 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc12) - %23 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc13) - %24 = arith.cmpi sgt, %20, %c0_i64 : i64 loc(#loc11) - %25 = arith.remsi %c0_i64, %18 : i64 loc(#loc11) - %26 = arith.cmpi eq, %25, %c0_i64 : i64 loc(#loc11) - %27 = arith.select %26, %4, %21 : i32 loc(#loc11) - %28 = arith.cmpi ne, %25, %c0_i64 : i64 loc(#loc11) - %29:4 = scf.if %26 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %110 = arith.divsi %4, %12 : i32 loc(#loc14) - %111 = arith.muli %110, %c8_i32 : i32 loc(#loc15) - %112 = arith.subi %6, %111 : i32 loc(#loc16) - %113 = arith.minsi %112, %c8_i32 : i32 loc(#loc17) - %114 = arith.remsi %4, %113 : i32 loc(#loc18) - %115 = arith.addi %111, %114 : i32 loc(#loc19) - %116 = arith.remsi %4, %12 : i32 loc(#loc20) - %117 = arith.divsi %116, %113 : i32 loc(#loc21) - %118 = arith.muli %115, %c128_i32 : i32 loc(#loc22) - %119 = arith.muli %117, %c256_i32 : i32 loc(#loc23) - %120 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %121 = tt.splat %118 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %122 = arith.addi %121, %120 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %123 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %124 = tt.splat %119 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %125 = arith.addi %124, %123 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %126 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %127 = 
arith.cmpi slt, %122, %126 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %128 = arith.select %127, %122, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %129 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %130 = arith.cmpi slt, %125, %129 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %131 = arith.select %130, %125, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %118, %119, %128, %131 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc59) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc60) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc61) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc62) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc63) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc64) + %7 = arith.muli %2, %4 : i32 loc(#loc8) + %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10) + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc10) + %11 = arith.subi %7, %0 : i32 loc(#loc11) + %12 = arith.ceildivsi %11, %c132_i32 : i32 loc(#loc11) + %13 = arith.extsi %6 : i32 to i64 loc(#loc11) + %14 = arith.maxsi %13, %c1_i64 : i64 loc(#loc11) + %15 = arith.extsi %12 : i32 to i64 loc(#loc11) + %16 = arith.muli %15, %14 : i64 loc(#loc11) + %17 = arith.subi %0, %c132_i32 : i32 loc(#loc11) + %18 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc12) + %19 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc13) + %20 = arith.cmpi sgt, %16, %c0_i64 : i64 loc(#loc11) + %21 = arith.remsi %c0_i64, %14 : i64 loc(#loc11) + %22 = arith.cmpi eq, %21, %c0_i64 : i64 loc(#loc11) + %23 = arith.select %22, %0, %17 : i32 loc(#loc11) + %24:4 = scf.if %22 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + %105 = arith.divsi %0, %8 : i32 loc(#loc14) + %106 = arith.muli %105, %c8_i32 : i32 loc(#loc15) + %107 = arith.subi %2, %106 : i32 loc(#loc16) + %108 = arith.minsi %107, %c8_i32 : i32 loc(#loc17) + %109 = arith.remsi %0, %108 : i32 loc(#loc18) + %110 = arith.addi %106, %109 : i32 loc(#loc19) + %111 = arith.remsi %0, %8 : i32 loc(#loc20) + %112 = arith.divsi %111, %108 : i32 loc(#loc21) + %113 = arith.muli %110, %c128_i32 : i32 loc(#loc22) + %114 = arith.muli %112, %c256_i32 : i32 loc(#loc23) + %115 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %116 = tt.splat %113 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %117 = arith.addi %116, %115 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %118 = 
tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %119 = tt.splat %114 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %120 = arith.addi %119, %118 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %121 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %122 = arith.cmpi slt, %117, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %123 = arith.select %122, %117, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %124 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %125 = arith.cmpi slt, %120, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %126 = arith.select %125, %120, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %113, %114, %123, %126 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) } else { - scf.yield %3, %3, %1, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) + scf.yield %c0_i32, %c0_i32, %cst_0, %cst : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) } loc(#loc11) - %30 = tt.expand_dims %29#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %31 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc33) - %32 = arith.muli %30, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) - %33 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %34 = tt.broadcast %32 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1> loc(#loc35) - %37 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc36) - %38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %39 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %40 = tt.expand_dims %29#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %41 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc39) - %42 = arith.muli %40, %41 : tensor<1x256xi32, #blocked> loc(#loc39) - %43 = tt.broadcast %39 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %44 = tt.broadcast %42 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %45 = arith.addi %43, %44 : tensor<64x256xi32, #blocked> loc(#loc40) - %46 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc41) - %47 = tt.addptr %46, %45 : 
tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %48 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %49 = arith.cmpi slt, %33, %48 : tensor<1x64xi32, #blocked1> loc(#loc42) - %50 = tt.broadcast %49 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %51 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %52 = tt.splat %24 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) - %53 = arith.andi %52, %50 : tensor<128x64xi1, #blocked1> loc(#loc11) - %54 = ttg.async_copy_global_to_local %38, %51 mask %53 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %55 = ttg.async_commit_group %54 loc(#loc12) - %56 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %57 = arith.cmpi slt, %39, %56 : tensor<64x1xi32, #blocked> loc(#loc43) - %58 = tt.broadcast %57 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %59 = ttg.memdesc_subview %23[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %60 = tt.splat %24 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %61 = arith.andi %60, %58 : tensor<64x256xi1, #blocked> loc(#loc11) - %62 = ttg.async_copy_global_to_local %47, %59 mask %61 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %63 = ttg.async_commit_group %62 loc(#loc13) - %64 = arith.cmpi sgt, %20, %c1_i64 : i64 loc(#loc11) - %65 = arith.addi %25, %c1_i64 : i64 loc(#loc11) - %66 = arith.remsi %65, %18 : i64 loc(#loc11) - %67 = arith.cmpi eq, %66, %c0_i64 : i64 loc(#loc11) - %68 = arith.cmpi ne, %66, %c0_i64 : i64 loc(#loc11) - %69 = arith.extui %68 : i1 to i32 loc(#loc11) - %70:5 = scf.if %67 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %110 = arith.addi %27, %c132_i32 : i32 loc(#loc11) - %111 = arith.divsi %110, %12 : i32 loc(#loc14) - %112 = arith.muli %111, %c8_i32 : i32 loc(#loc15) - %113 = arith.subi %6, %112 : i32 loc(#loc16) - %114 = arith.minsi %113, %c8_i32 : i32 loc(#loc17) - %115 = arith.remsi %110, %114 : i32 loc(#loc18) - %116 = arith.addi %112, %115 : i32 loc(#loc19) - %117 = arith.remsi %110, %12 : i32 loc(#loc20) - %118 = arith.divsi %117, %114 : i32 loc(#loc21) - %119 = arith.muli %116, %c128_i32 : i32 loc(#loc22) - %120 = arith.muli %118, %c256_i32 : i32 loc(#loc23) - %121 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %122 = tt.splat %119 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %123 = arith.addi %122, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %124 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %125 = tt.splat %120 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %126 = arith.addi %125, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %127 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %128 = arith.cmpi slt, %123, %127 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> 
loc(#loc28) - %129 = arith.select %128, %123, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %130 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %131 = arith.cmpi slt, %126, %130 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %132 = arith.select %131, %126, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %119, %120, %129, %132, %110 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + %25 = tt.expand_dims %24#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %26 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc33) + %27 = arith.muli %25, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) + %28 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) + %29 = tt.broadcast %27 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %30 = tt.broadcast %28 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %31 = arith.addi %29, %30 : tensor<128x64xi32, #blocked1> loc(#loc35) + %32 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc36) + %33 = tt.addptr %32, %31 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %34 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %35 = tt.expand_dims %24#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %36 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc39) + %37 = arith.muli %35, %36 : tensor<1x256xi32, #blocked> loc(#loc39) + %38 = tt.broadcast %34 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %39 = tt.broadcast %37 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %40 = arith.addi %38, %39 : tensor<64x256xi32, #blocked> loc(#loc40) + %41 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc41) + %42 = tt.addptr %41, %40 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %43 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %44 = arith.cmpi slt, %28, %43 : tensor<1x64xi32, #blocked1> loc(#loc42) + %45 = tt.broadcast %44 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %46 = ttg.memdesc_subview %18[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %47 = tt.splat %20 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %48 = arith.andi %47, %45 : tensor<128x64xi1, #blocked1> loc(#loc11) + %49 = ttg.async_copy_global_to_local %33, %46 mask %48 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %50 = ttg.async_commit_group %49 loc(#loc12) + %51 = tt.splat %arg5 : i32 -> 
tensor<64x1xi32, #blocked> loc(#loc43) + %52 = arith.cmpi slt, %34, %51 : tensor<64x1xi32, #blocked> loc(#loc43) + %53 = tt.broadcast %52 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %54 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %55 = tt.splat %20 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %56 = arith.andi %55, %53 : tensor<64x256xi1, #blocked> loc(#loc11) + %57 = ttg.async_copy_global_to_local %42, %54 mask %56 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %58 = ttg.async_commit_group %57 loc(#loc13) + %59 = arith.cmpi sgt, %16, %c1_i64 : i64 loc(#loc11) + %60 = arith.addi %21, %c1_i64 : i64 loc(#loc11) + %61 = arith.remsi %60, %14 : i64 loc(#loc11) + %62 = arith.cmpi eq, %61, %c0_i64 : i64 loc(#loc11) + %63 = arith.cmpi ne, %61, %c0_i64 : i64 loc(#loc11) + %64 = arith.extui %63 : i1 to i32 loc(#loc11) + %65:5 = scf.if %62 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %105 = arith.addi %23, %c132_i32 : i32 loc(#loc11) + %106 = arith.divsi %105, %8 : i32 loc(#loc14) + %107 = arith.muli %106, %c8_i32 : i32 loc(#loc15) + %108 = arith.subi %2, %107 : i32 loc(#loc16) + %109 = arith.minsi %108, %c8_i32 : i32 loc(#loc17) + %110 = arith.remsi %105, %109 : i32 loc(#loc18) + %111 = arith.addi %107, %110 : i32 loc(#loc19) + %112 = arith.remsi %105, %8 : i32 loc(#loc20) + %113 = arith.divsi %112, %109 : i32 loc(#loc21) + %114 = arith.muli %111, %c128_i32 : i32 loc(#loc22) + %115 = arith.muli %113, %c256_i32 : i32 loc(#loc23) + %116 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %117 = tt.splat %114 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %118 = arith.addi %117, %116 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %119 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %120 = tt.splat %115 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %121 = arith.addi %120, %119 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %122 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %123 = arith.cmpi slt, %118, %122 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %124 = arith.select %123, %118, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %125 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %126 = arith.cmpi slt, %121, %125 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %127 = arith.select %126, %121, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %114, %115, %124, %127, %105 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, 
parent = #blocked}>>, i32 loc(#loc11) } else { - scf.yield %29#0, %29#1, %29#2, %29#3, %27 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + scf.yield %24#0, %24#1, %24#2, %24#3, %23 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) } loc(#loc11) - %71 = arith.muli %69, %c64_i32 : i32 loc(#loc44) - %72 = tt.splat %71 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %73 = tt.splat %71 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %74 = arith.addi %72, %13 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %75 = arith.addi %73, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %76 = tt.expand_dims %70#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %77 = arith.muli %76, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) - %78 = tt.expand_dims %74 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %79 = tt.broadcast %77 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %80 = tt.broadcast %78 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %81 = arith.addi %79, %80 : tensor<128x64xi32, #blocked1> loc(#loc35) - %82 = tt.addptr %37, %81 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %83 = tt.expand_dims %75 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %84 = tt.expand_dims %70#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %85 = arith.muli %84, %41 : tensor<1x256xi32, #blocked> loc(#loc39) - %86 = tt.broadcast %83 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %87 = tt.broadcast %85 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %88 = arith.addi %86, %87 : tensor<64x256xi32, #blocked> loc(#loc40) - %89 = tt.addptr %46, %88 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %90 = arith.subi %arg5, %71 : i32 loc(#loc46) - %91 = tt.splat %90 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %92 = arith.cmpi slt, %33, %91 : tensor<1x64xi32, #blocked1> loc(#loc42) - %93 = tt.broadcast %92 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %94 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %95 = tt.splat %64 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) - %96 = arith.andi %95, %93 : tensor<128x64xi1, #blocked1> loc(#loc11) - %97 = ttg.async_copy_global_to_local %82, %94 mask %96 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %98 = ttg.async_commit_group %97 loc(#loc12) - %99 = tt.splat %90 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %100 = arith.cmpi slt, %39, %99 : tensor<64x1xi32, #blocked> loc(#loc43) - %101 = tt.broadcast %100 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %102 = ttg.memdesc_subview %23[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, 
#smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %103 = tt.splat %64 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %104 = arith.andi %103, %101 : tensor<64x256xi1, #blocked> loc(#loc11) - %105 = ttg.async_copy_global_to_local %89, %102 mask %104 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %106 = ttg.async_commit_group %105 loc(#loc13) - %107:20 = scf.for %arg9 = %c0_i64 to %20 step %c1_i64 iter_args(%arg10 = %66, %arg11 = %70#4, %arg12 = %2, %arg13 = %70#0, %arg14 = %70#1, %arg15 = %70#2, %arg16 = %70#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %69, %arg20 = %63, %arg21 = %106, %arg22 = %28, %arg23 = %68, %arg24 = %25, %arg25 = %66, %arg26 = %29#0, %arg27 = %70#0, %arg28 = %29#1, %arg29 = %70#1) -> ( - i64, i32, - tensor<128x256xf32, #mma>, - i32, i32, - tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, - tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, - i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, i64, i64, i32, i32, i32, i32) : i64 { - %110 = arith.subi %20, %c2_i64 : i64 loc(#loc11) - %111 = arith.cmpi slt, %arg9, %110 : i64 loc(#loc11) - %112 = arith.addi %arg19, %c1_i32 : i32 loc(#loc11) - %113 = arith.addi %arg10, %c1_i64 : i64 loc(#loc11) - %114 = arith.remsi %113, %18 : i64 loc(#loc11) - %115 = arith.cmpi eq, %114, %c0_i64 : i64 loc(#loc11) - %116 = arith.select %115, %c0_i32, %112 : i32 loc(#loc11) - %117 = arith.cmpi ne, %114, %c0_i64 : i64 loc(#loc11) - %118:5 = scf.if %115 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %168 = arith.addi %arg11, %c132_i32 : i32 loc(#loc11) - %169 = arith.divsi %168, %12 : i32 loc(#loc14) - %170 = arith.muli %169, %c8_i32 : i32 loc(#loc15) - %171 = arith.subi %6, %170 : i32 loc(#loc16) - %172 = arith.minsi %171, %c8_i32 : i32 loc(#loc17) - %173 = arith.remsi %168, %172 : i32 loc(#loc18) - %174 = arith.addi %170, %173 : i32 loc(#loc19) - %175 = arith.remsi %168, %12 : i32 loc(#loc20) - %176 = arith.divsi %175, %172 : i32 loc(#loc21) - %177 = arith.muli %174, %c128_i32 : i32 loc(#loc22) - %178 = arith.muli %176, %c256_i32 : i32 loc(#loc23) - %179 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %180 = tt.splat %177 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %181 = arith.addi %180, %179 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %182 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %183 = tt.splat %178 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %184 = arith.addi %183, %182 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %185 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %186 = arith.cmpi slt, %181, %185 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %187 = arith.select %186, %181, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %188 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %189 = arith.cmpi slt, %184, 
%188 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %190 = arith.select %189, %184, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %177, %178, %187, %190, %168 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) + %66 = arith.muli %64, %c64_i32 : i32 loc(#loc44) + %67 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %68 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %69 = arith.addi %67, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %70 = arith.addi %68, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %71 = tt.expand_dims %65#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %72 = arith.muli %71, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) + %73 = tt.expand_dims %69 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) + %74 = tt.broadcast %72 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %75 = tt.broadcast %73 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %76 = arith.addi %74, %75 : tensor<128x64xi32, #blocked1> loc(#loc35) + %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %78 = tt.expand_dims %70 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %79 = tt.expand_dims %65#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %80 = arith.muli %79, %36 : tensor<1x256xi32, #blocked> loc(#loc39) + %81 = tt.broadcast %78 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %82 = tt.broadcast %80 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %83 = arith.addi %81, %82 : tensor<64x256xi32, #blocked> loc(#loc40) + %84 = tt.addptr %41, %83 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %85 = arith.subi %arg5, %66 : i32 loc(#loc46) + %86 = tt.splat %85 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %87 = arith.cmpi slt, %28, %86 : tensor<1x64xi32, #blocked1> loc(#loc42) + %88 = tt.broadcast %87 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %89 = ttg.memdesc_subview %18[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %90 = tt.splat %59 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %91 = arith.andi %90, %88 : tensor<128x64xi1, #blocked1> loc(#loc11) + %92 = ttg.async_copy_global_to_local %77, %89 mask %91 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %93 = ttg.async_commit_group %92 loc(#loc12) + %94 = tt.splat %85 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) + %95 = arith.cmpi slt, %34, %94 : tensor<64x1xi32, #blocked> loc(#loc43) + %96 = tt.broadcast %95 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %97 = 
ttg.memdesc_subview %19[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %98 = tt.splat %59 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %99 = arith.andi %98, %96 : tensor<64x256xi1, #blocked> loc(#loc11) + %100 = ttg.async_copy_global_to_local %84, %97 mask %99 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %101 = ttg.async_commit_group %100 loc(#loc13) + %102:18 = scf.for %arg9 = %c0_i64 to %16 step %c1_i64 iter_args(%arg10 = %61, %arg11 = %65#4, %arg12 = %cst_3, %arg13 = %65#0, %arg14 = %65#1, %arg15 = %65#2, %arg16 = %65#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %64, %arg20 = %21, %arg21 = %61, %arg22 = %58, %arg23 = %101, %arg24 = %24#0, %arg25 = %65#0, %arg26 = %24#1, %arg27 = %65#1) -> (i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i64, i64, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i64 { + %105 = arith.subi %16, %c2_i64 : i64 loc(#loc11) + %106 = arith.cmpi slt, %arg9, %105 : i64 loc(#loc11) + %107 = arith.addi %arg19, %c1_i32 : i32 loc(#loc11) + %108 = arith.addi %arg10, %c1_i64 : i64 loc(#loc11) + %109 = arith.remsi %108, %14 : i64 loc(#loc11) + %110 = arith.cmpi eq, %109, %c0_i64 : i64 loc(#loc11) + %111 = arith.select %110, %c0_i32, %107 : i32 loc(#loc11) + %112:5 = scf.if %110 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %163 = arith.addi %arg11, %c132_i32 : i32 loc(#loc11) + %164 = arith.divsi %163, %8 : i32 loc(#loc14) + %165 = arith.muli %164, %c8_i32 : i32 loc(#loc15) + %166 = arith.subi %2, %165 : i32 loc(#loc16) + %167 = arith.minsi %166, %c8_i32 : i32 loc(#loc17) + %168 = arith.remsi %163, %167 : i32 loc(#loc18) + %169 = arith.addi %165, %168 : i32 loc(#loc19) + %170 = arith.remsi %163, %8 : i32 loc(#loc20) + %171 = arith.divsi %170, %167 : i32 loc(#loc21) + %172 = arith.muli %169, %c128_i32 : i32 loc(#loc22) + %173 = arith.muli %171, %c256_i32 : i32 loc(#loc23) + %174 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) + %175 = tt.splat %172 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %176 = arith.addi %175, %174 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) + %177 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) + %178 = tt.splat %173 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %179 = arith.addi %178, %177 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) + %180 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %181 = arith.cmpi slt, %176, %180 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) + %182 = arith.select %181, %176, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) + %183 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %184 = arith.cmpi slt, %179, %183 : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) + %185 = arith.select %184, %179, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) + scf.yield %172, %173, %182, %185, %163 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) } else { scf.yield %arg13, %arg14, %arg15, %arg16, %arg11 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) } loc(#loc11) - %119 = arith.addi %arg18, %c1_i32 : i32 loc(#loc11) - %120 = arith.cmpi slt, %119, %c3_i32 : i32 loc(#loc11) - %121 = arith.select %120, %119, %c0_i32 : i32 loc(#loc11) - %122 = ttg.memdesc_subview %22[%121, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %123 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc12) - %124 = ttg.memdesc_subview %23[%121, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %125 = ttng.warp_group_dot %122, %124, %arg12, %arg22 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc47) - %126:3 = ttng.warp_group_dot_wait %125, %122, %124 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) - %127 = arith.addi %arg17, %c1_i32 : i32 loc(#loc11) - %128 = arith.cmpi slt, %127, %c3_i32 : i32 loc(#loc11) - %129 = arith.select %128, %127, %c0_i32 : i32 loc(#loc11) - %130 = arith.muli %116, %c64_i32 : i32 loc(#loc44) - %131 = tt.splat %130 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %132 = tt.splat %130 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %133 = arith.addi %131, %13 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %134 = arith.addi %132, %14 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %135 = tt.expand_dims %118#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %136 = arith.muli %135, %31 : tensor<128x1xi32, #blocked1> loc(#loc33) - %137 = tt.expand_dims %133 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %138 = tt.broadcast %136 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %139 = tt.broadcast %137 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %140 = arith.addi %138, %139 : tensor<128x64xi32, #blocked1> loc(#loc35) - %141 = tt.addptr %37, %140 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %142 = tt.expand_dims %134 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %143 = tt.expand_dims %118#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %144 = arith.muli 
%143, %41 : tensor<1x256xi32, #blocked> loc(#loc39) - %145 = tt.broadcast %142 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %146 = tt.broadcast %144 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %147 = arith.addi %145, %146 : tensor<64x256xi32, #blocked> loc(#loc40) - %148 = tt.addptr %46, %147 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %149 = arith.subi %arg5, %130 : i32 loc(#loc46) - %150 = tt.splat %149 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %151 = arith.cmpi slt, %33, %150 : tensor<1x64xi32, #blocked1> loc(#loc42) - %152 = tt.broadcast %151 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %153 = ttg.memdesc_subview %22[%129, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %154 = tt.splat %111 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) - %155 = arith.andi %154, %152 : tensor<128x64xi1, #blocked1> loc(#loc11) - %156 = ttg.async_copy_global_to_local %141, %153 mask %155 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %157 = ttg.async_commit_group %156 loc(#loc12) - %158 = tt.splat %149 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %159 = arith.cmpi slt, %39, %158 : tensor<64x1xi32, #blocked> loc(#loc43) - %160 = tt.broadcast %159 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %161 = ttg.memdesc_subview %23[%129, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %162 = tt.splat %111 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %163 = arith.andi %162, %160 : tensor<64x256xi1, #blocked> loc(#loc11) - %164 = ttg.async_copy_global_to_local %148, %161 mask %163 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %165 = ttg.async_commit_group %164 loc(#loc13) - %166 = arith.subi %18, %c1_i64 : i64 loc(#loc11) - %167 = arith.cmpi eq, %arg24, %166 : i64 loc(#loc11) - scf.if %167 { - %168:3 = ttng.warp_group_dot_wait %126#0, %122, %124 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) - %169 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc24) - %170 = tt.splat %arg26 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) - %171 = arith.addi %170, %169 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) - %172 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc26) - %173 = tt.splat %arg28 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) - %174 = arith.addi %173, %172 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) - %175 = tt.expand_dims %171 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc48) - %176 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc49) - %177 = arith.muli %176, %175 : tensor<128x1xi32, #blocked2> loc(#loc49) - %178 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc50) - %179 = tt.addptr %178, %177 : 
tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc50) - %180 = tt.expand_dims %174 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc51) - %181 = tt.broadcast %179 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc52) - %182 = tt.broadcast %180 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc52) - %183 = tt.addptr %181, %182 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc52) - %184 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc53) - %185 = arith.cmpi slt, %175, %184 : tensor<128x1xi32, #blocked2> loc(#loc53) - %186 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc54) - %187 = arith.cmpi slt, %180, %186 : tensor<1x256xi32, #blocked2> loc(#loc54) - %188 = tt.broadcast %185 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) - %189 = tt.broadcast %187 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) - %190 = arith.andi %188, %189 : tensor<128x256xi1, #blocked2> loc(#loc55) - %191 = arith.truncf %168#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc56) - %192 = ttg.convert_layout %191 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc57) - tt.store %183, %192, %190 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc57) + %113 = arith.addi %arg18, %c1_i32 : i32 loc(#loc11) + %114 = arith.cmpi slt, %113, %c3_i32 : i32 loc(#loc11) + %115 = arith.select %114, %113, %c0_i32 : i32 loc(#loc11) + %116 = arith.cmpi ne, %arg20, %c0_i64 : i64 loc(#loc65) + %117 = ttg.memdesc_subview %18[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %118 = ttg.async_wait %arg22 {num = 2 : i32} loc(#loc12) + %119 = ttg.memdesc_subview %19[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %120 = ttng.warp_group_dot %117, %119, %arg12, %116 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc47) + %121:3 = ttng.warp_group_dot_wait %120, %117, %119 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) + %122 = arith.addi %arg17, %c1_i32 : i32 loc(#loc11) + %123 = arith.cmpi slt, %122, %c3_i32 : i32 loc(#loc11) + %124 = arith.select %123, %122, %c0_i32 : i32 loc(#loc11) + %125 = arith.muli %111, %c64_i32 : i32 loc(#loc44) + %126 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %127 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %128 = arith.addi %126, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) + %129 = arith.addi %127, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) + %130 = tt.expand_dims %112#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) + %131 = arith.muli %130, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) + %132 = tt.expand_dims %128 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> 
tensor<1x64xi32, #blocked1> loc(#loc34) + %133 = tt.broadcast %131 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %134 = tt.broadcast %132 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) + %135 = arith.addi %133, %134 : tensor<128x64xi32, #blocked1> loc(#loc35) + %136 = tt.addptr %32, %135 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) + %137 = tt.expand_dims %129 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) + %138 = tt.expand_dims %112#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) + %139 = arith.muli %138, %36 : tensor<1x256xi32, #blocked> loc(#loc39) + %140 = tt.broadcast %137 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %141 = tt.broadcast %139 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) + %142 = arith.addi %140, %141 : tensor<64x256xi32, #blocked> loc(#loc40) + %143 = tt.addptr %41, %142 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) + %144 = arith.subi %arg5, %125 : i32 loc(#loc46) + %145 = tt.splat %144 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) + %146 = arith.cmpi slt, %28, %145 : tensor<1x64xi32, #blocked1> loc(#loc42) + %147 = tt.broadcast %146 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) + %148 = ttg.memdesc_subview %18[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %149 = tt.splat %106 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) + %150 = arith.andi %149, %147 : tensor<128x64xi1, #blocked1> loc(#loc11) + %151 = ttg.async_copy_global_to_local %136, %148 mask %150 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) + %152 = ttg.async_commit_group %151 loc(#loc12) + %153 = tt.splat %144 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) + %154 = arith.cmpi slt, %34, %153 : tensor<64x1xi32, #blocked> loc(#loc43) + %155 = tt.broadcast %154 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) + %156 = ttg.memdesc_subview %19[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %157 = tt.splat %106 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) + %158 = arith.andi %157, %155 : tensor<64x256xi1, #blocked> loc(#loc11) + %159 = ttg.async_copy_global_to_local %143, %156 mask %158 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) + %160 = ttg.async_commit_group %159 loc(#loc13) + %161 = arith.subi %14, %c1_i64 : i64 loc(#loc11) + %162 = arith.cmpi eq, %arg20, %161 : i64 loc(#loc11) + scf.if %162 { + %163:3 = ttng.warp_group_dot_wait %121#0, %117, %119 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) + %164 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc24) + %165 = tt.splat %arg24 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) + %166 = arith.addi %165, %164 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) + %167 = tt.make_range 
{end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc26) + %168 = tt.splat %arg26 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) + %169 = arith.addi %168, %167 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) + %170 = tt.expand_dims %166 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc48) + %171 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc49) + %172 = arith.muli %171, %170 : tensor<128x1xi32, #blocked2> loc(#loc49) + %173 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc50) + %174 = tt.addptr %173, %172 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc50) + %175 = tt.expand_dims %169 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc51) + %176 = tt.broadcast %174 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc52) + %177 = tt.broadcast %175 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc52) + %178 = tt.addptr %176, %177 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc52) + %179 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc53) + %180 = arith.cmpi slt, %170, %179 : tensor<128x1xi32, #blocked2> loc(#loc53) + %181 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc54) + %182 = arith.cmpi slt, %175, %181 : tensor<1x256xi32, #blocked2> loc(#loc54) + %183 = tt.broadcast %180 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) + %184 = tt.broadcast %182 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) + %185 = arith.andi %183, %184 : tensor<128x256xi1, #blocked2> loc(#loc55) + %186 = arith.truncf %163#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc56) + %187 = ttg.convert_layout %186 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc57) + tt.store %178, %187, %185 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc57) } loc(#loc11) - scf.yield %114, %118#4, %126#0, %118#0, %118#1, %118#2, %118#3, %129, %121, %116, %arg21, %165, %arg23, %117, %arg25, %114, %arg27, %118#0, %arg29, %118#1 : i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, i64, i64, i32, i32, i32, i32 loc(#loc11) + scf.yield %109, %112#4, %121#0, %112#0, %112#1, %112#2, %112#3, %124, %115, %111, %arg21, %109, %arg23, %160, %arg25, %112#0, %arg27, %112#1 : i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i64, i64, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 loc(#loc11) } loc(#loc11) - %108 = ttng.warp_group_dot_wait %107#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc11) - %109 = ttg.async_wait {num = 0 : i32} loc(#loc11) - ttg.local_dealloc %22 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc11) - ttg.local_dealloc %23 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc11) + %103 = ttng.warp_group_dot_wait %102#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc11) + %104 = ttg.async_wait {num = 0 : i32} loc(#loc11) + ttg.local_dealloc %18 : !ttg.memdesc<3x128x64xf16, 
#shared, #smem, mutable> loc(#loc11) + ttg.local_dealloc %19 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc11) tt.return loc(#loc58) } loc(#loc) } loc(#loc) #loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":242:30) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":281:30) #loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":243:27) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":282:27) #loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":244:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":245:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":246:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":247:38) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":249:35) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":251:47) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:24) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":274:24) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":252:30) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":253:33) -#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":254:39) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":254:52) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":255:41) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":255:31) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":256:27) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":256:48) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":258:26) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":259:26) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":260:41) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":260:28) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":261:41) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":261:28) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":262:37) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":262:49) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":263:37) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":263:49) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:38) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:49) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:68) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:61) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:30) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:37) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:68) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:79) 
-#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:60) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":271:30) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:64) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":274:64) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":269:26) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":269:41) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":273:68) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":275:39) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:45) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:37) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:25) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:76) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":279:56) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:37) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:62) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":280:43) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":284:31) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":285:25) -#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":251:4) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":283:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":284:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":285:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":286:38) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":288:35) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":290:47) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:24) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":313:24) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":291:30) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":292:33) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":293:39) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":293:52) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":294:41) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":294:31) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":295:27) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":295:48) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":297:26) +#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":298:26) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":299:41) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":299:28) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":300:41) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":300:28) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":301:37) +#loc29 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":301:49) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":302:37) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":302:49) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:38) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:49) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:68) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:61) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:30) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:37) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:68) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:79) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:60) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:30) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:64) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":313:64) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":308:26) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":308:41) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:68) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":314:39) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:45) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:37) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:25) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:76) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:56) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:37) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:62) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:43) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":323:31) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":324:25) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":290:4) #loc59 = loc(callsite(#loc3 at #loc4)) #loc60 = loc(callsite(#loc5 at #loc4)) #loc61 = loc(callsite(#loc3 at #loc6)) #loc62 = loc(callsite(#loc5 at #loc6)) #loc63 = loc(callsite(#loc3 at #loc7)) #loc64 = loc(callsite(#loc5 at #loc7)) +#loc65 = loc(fused[#loc47, #loc11]) diff --git a/new2.mlir b/new2.mlir new file mode 100644 index 000000000000..9e82a5ef1817 --- /dev/null +++ b/new2.mlir @@ -0,0 +1,309 @@ +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> 
+#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i64 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i64 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i32 + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %range_1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %range_2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %range_3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %range_4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %splat_1 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %splat_2 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.muli %4, %c8_i32 : i32 + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %11 = arith.subi %7, %0 : i32 + %12 = arith.ceildivsi %11, %c132_i32 : i32 + %13 = arith.addi %6, %c0_i32 : i32 + %14 = arith.maxsi %13, %c1_i64 : i32 + %15 = arith.addi %12, %c0_i32 : i32 + %16 = arith.muli %15, %14 : i32 + %17 = arith.subi %0, %c132_i32 : i32 + %18 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %19 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %20 = arith.cmpi sgt, %16, %c0_i64 : i32 + %21 = arith.constant 0 : i32 + %22 = arith.cmpi eq, %21, %c0_i64 : i32 + %23 = arith.select %22, %0, %17 : i32 + %24:4 = scf.if %22 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim 
= 0, parent = #blocked}>>) { + %105 = arith.divsi %0, %8 : i32 + %106 = arith.muli %105, %c8_i32 : i32 + %107 = arith.subi %2, %106 : i32 + %108 = arith.minsi %107, %c8_i32 : i32 + %109 = arith.remsi %0, %108 : i32 + %110 = arith.addi %106, %109 : i32 + %111 = arith.remsi %0, %8 : i32 + %112 = arith.divsi %111, %108 : i32 + %113 = arith.muli %110, %c128_i32 : i32 + %114 = arith.muli %112, %c256_i32 : i32 + %115 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %116 = tt.splat %113 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %117 = arith.addi %116, %115 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %118 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %119 = tt.splat %114 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %120 = arith.addi %119, %118 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %121 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %122 = arith.cmpi slt, %117, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %123 = arith.select %122, %117, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %124 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %125 = arith.cmpi slt, %120, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %126 = arith.select %125, %120, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %113, %114, %123, %126 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } else { + scf.yield %c0_i32, %c0_i32, %cst_0, %cst : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + } + %25 = tt.expand_dims %24#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %26 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %27 = arith.muli %25, %26 : tensor<128x1xi32, #blocked1> + %28 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %29 = tt.broadcast %27 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %30 = tt.broadcast %28 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %31 = arith.addi %29, %30 : tensor<128x64xi32, #blocked1> + %32 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %33 = tt.addptr %32, %31 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %34 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %35 = tt.expand_dims %24#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %36 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %37 = arith.muli %35, %36 : tensor<1x256xi32, #blocked> + %38 = tt.broadcast %34 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %39 = tt.broadcast %37 : tensor<1x256xi32, 
#blocked> -> tensor<64x256xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<64x256xi32, #blocked> + %41 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %42 = tt.addptr %41, %40 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %43 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %44 = arith.cmpi slt, %28, %43 : tensor<1x64xi32, #blocked1> + %45 = tt.broadcast %44 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %46 = ttg.memdesc_subview %18[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %47 = tt.splat %20 : i1 -> tensor<128x64xi1, #blocked1> + %48 = arith.andi %47, %45 : tensor<128x64xi1, #blocked1> + %49 = ttg.async_copy_global_to_local %33, %46 mask %48 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %50 = ttg.async_commit_group %49 + %51 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> + %52 = arith.cmpi slt, %34, %51 : tensor<64x1xi32, #blocked> + %53 = tt.broadcast %52 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %54 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %55 = tt.splat %20 : i1 -> tensor<64x256xi1, #blocked> + %56 = arith.andi %55, %53 : tensor<64x256xi1, #blocked> + %57 = ttg.async_copy_global_to_local %42, %54 mask %56 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %58 = ttg.async_commit_group %57 + %59 = arith.cmpi sgt, %16, %c1_i64 : i32 + %60 = arith.addi %21, %c1_i64 : i32 + %61 = arith.remsi %60, %14 : i32 + %62 = arith.cmpi eq, %61, %c0_i64 : i32 + %63 = arith.cmpi ne, %61, %c0_i64 : i32 + %64 = arith.extui %63 : i1 to i32 + %65:5 = scf.if %62 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %105 = arith.addi %23, %c132_i32 : i32 + %106 = arith.divsi %105, %8 : i32 + %107 = arith.muli %106, %c8_i32 : i32 + %108 = arith.subi %2, %107 : i32 + %109 = arith.minsi %108, %c8_i32 : i32 + %110 = arith.remsi %105, %109 : i32 + %111 = arith.addi %107, %110 : i32 + %112 = arith.remsi %105, %8 : i32 + %113 = arith.divsi %112, %109 : i32 + %114 = arith.muli %111, %c128_i32 : i32 + %115 = arith.muli %113, %c256_i32 : i32 + %116 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %117 = tt.splat %114 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %118 = arith.addi %117, %116 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %119 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %120 = tt.splat %115 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %121 = arith.addi %120, %119 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %122 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %123 = arith.cmpi slt, %118, %122 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %124 = arith.select %123, %118, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %125 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = 
#blocked}>> + %126 = arith.cmpi slt, %121, %125 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %127 = arith.select %126, %121, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %114, %115, %124, %127, %105 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } else { + scf.yield %24#0, %24#1, %24#2, %24#3, %23 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } + %66 = arith.muli %64, %c64_i32 : i32 + %67 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %68 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %69 = arith.addi %67, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %70 = arith.addi %68, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %71 = tt.expand_dims %65#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %72 = arith.muli %71, %26 : tensor<128x1xi32, #blocked1> + %73 = tt.expand_dims %69 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %74 = tt.broadcast %72 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %75 = tt.broadcast %73 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %76 = arith.addi %74, %75 : tensor<128x64xi32, #blocked1> + %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %78 = tt.expand_dims %70 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %79 = tt.expand_dims %65#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %80 = arith.muli %79, %36 : tensor<1x256xi32, #blocked> + %81 = tt.broadcast %78 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %82 = tt.broadcast %80 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %83 = arith.addi %81, %82 : tensor<64x256xi32, #blocked> + %84 = tt.addptr %41, %83 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %85 = arith.subi %arg5, %66 : i32 + %86 = tt.splat %85 : i32 -> tensor<1x64xi32, #blocked1> + %87 = arith.cmpi slt, %28, %86 : tensor<1x64xi32, #blocked1> + %88 = tt.broadcast %87 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %89 = ttg.memdesc_subview %18[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %90 = tt.splat %59 : i1 -> tensor<128x64xi1, #blocked1> + %91 = arith.andi %90, %88 : tensor<128x64xi1, #blocked1> + %92 = ttg.async_copy_global_to_local %77, %89 mask %91 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %93 = ttg.async_commit_group %92 + %94 = tt.splat %85 : i32 -> tensor<64x1xi32, #blocked> + %95 = arith.cmpi slt, %34, %94 : tensor<64x1xi32, #blocked> + %96 = tt.broadcast %95 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %97 = ttg.memdesc_subview %19[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %98 = tt.splat %59 : i1 -> 
tensor<64x256xi1, #blocked> + %99 = arith.andi %98, %96 : tensor<64x256xi1, #blocked> + %100 = ttg.async_copy_global_to_local %84, %97 mask %99 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %101 = ttg.async_commit_group %100 + %102:18 = scf.for %arg9 = %c0_i64 to %16 step %c1_i64 iter_args(%arg10 = %61, %arg11 = %65#4, %arg12 = %cst_3, %arg13 = %65#0, %arg14 = %65#1, %arg15 = %65#2, %arg16 = %65#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %64, %arg20 = %21, %arg21 = %61, %arg22 = %58, %arg23 = %101, %arg24 = %24#0, %arg25 = %65#0, %arg26 = %24#1, %arg27 = %65#1) -> (i32, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { + %105 = arith.subi %16, %c2_i64 : i32 + %106 = arith.cmpi slt, %arg9, %105 : i32 + %107 = arith.addi %arg19, %c1_i32 : i32 + %108 = arith.addi %107, %c0_i32 : i32 + %109 = arith.remsi %108, %14 : i32 + %110 = arith.cmpi eq, %109, %c0_i64 : i32 + %111 = arith.select %110, %c0_i32, %107 : i32 + %112:5 = scf.if %110 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %163 = arith.addi %arg11, %c132_i32 : i32 + %164 = arith.divsi %163, %8 : i32 + %165 = arith.muli %164, %c8_i32 : i32 + %166 = arith.subi %2, %165 : i32 + %167 = arith.minsi %166, %c8_i32 : i32 + %168 = arith.remsi %163, %167 : i32 + %169 = arith.addi %165, %168 : i32 + %170 = arith.remsi %163, %8 : i32 + %171 = arith.divsi %170, %167 : i32 + %172 = arith.muli %169, %c128_i32 : i32 + %173 = arith.muli %171, %c256_i32 : i32 + %175 = tt.splat %172 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %176 = arith.addi %175, %range_3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %178 = tt.splat %173 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %179 = arith.addi %178, %range_4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %181 = arith.cmpi slt, %176, %splat_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %182 = arith.select %181, %176, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %184 = arith.cmpi slt, %179, %splat_2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %185 = arith.select %184, %179, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %172, %173, %182, %185, %163 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } else { + scf.yield %arg13, %arg14, %arg15, %arg16, %arg11 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } + %113 = arith.addi %arg18, %c1_i32 : i32 + %114 = arith.cmpi slt, %113, %c3_i32 : i32 + %115 = arith.select %114, %113, %c0_i32 : i32 + %116 = arith.cmpi ne, %arg20, %c0_i64 : i32 + %117 = ttg.memdesc_subview %18[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, 
mutable> + %118 = ttg.async_wait %arg22 {num = 2 : i32} + %119 = ttg.memdesc_subview %19[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %120 = ttng.warp_group_dot %117, %119, %arg12, %116 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> + %121:3 = ttng.warp_group_dot_wait %120, %117, %119 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %122 = arith.addi %arg17, %c1_i32 : i32 + %123 = arith.cmpi slt, %122, %c3_i32 : i32 + %124 = arith.select %123, %122, %c0_i32 : i32 + %125 = arith.muli %111, %c64_i32 : i32 + %126 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %127 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %128 = arith.addi %126, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %129 = arith.addi %127, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %130 = tt.expand_dims %112#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %131 = arith.muli %130, %26 : tensor<128x1xi32, #blocked1> + %132 = tt.expand_dims %128 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %133 = tt.broadcast %131 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %134 = tt.broadcast %132 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %135 = arith.addi %133, %134 : tensor<128x64xi32, #blocked1> + %136 = tt.addptr %32, %135 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %137 = tt.expand_dims %129 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %138 = tt.expand_dims %112#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %139 = arith.muli %138, %36 : tensor<1x256xi32, #blocked> + %140 = tt.broadcast %137 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %141 = tt.broadcast %139 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %142 = arith.addi %140, %141 : tensor<64x256xi32, #blocked> + %143 = tt.addptr %41, %142 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %144 = arith.subi %arg5, %125 : i32 + %145 = tt.splat %144 : i32 -> tensor<1x64xi32, #blocked1> + %146 = arith.cmpi slt, %28, %145 : tensor<1x64xi32, #blocked1> + %147 = tt.broadcast %146 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %148 = ttg.memdesc_subview %18[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %149 = tt.splat %106 : i1 -> tensor<128x64xi1, #blocked1> + %150 = arith.andi %149, %147 : tensor<128x64xi1, #blocked1> + %151 = ttg.async_copy_global_to_local %136, %148 mask %150 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %152 = ttg.async_commit_group %151 + %153 = tt.splat %144 : i32 -> tensor<64x1xi32, #blocked> + %154 = arith.cmpi slt, %34, %153 : tensor<64x1xi32, #blocked> + %155 = tt.broadcast %154 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %156 = ttg.memdesc_subview %19[%124, %c0_i32, %c0_i32] : 
!ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %157 = tt.splat %106 : i1 -> tensor<64x256xi1, #blocked> + %158 = arith.andi %157, %155 : tensor<64x256xi1, #blocked> + %159 = ttg.async_copy_global_to_local %143, %156 mask %158 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %160 = ttg.async_commit_group %159 + %161 = arith.subi %14, %c1_i64 : i32 + %162 = arith.cmpi eq, %arg20, %161 : i32 + scf.if %162 { + %163:3 = ttng.warp_group_dot_wait %121#0, %117, %119 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %165 = tt.splat %arg24 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %166 = arith.addi %165, %range_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %168 = tt.splat %arg26 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %169 = arith.addi %168, %range_2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %170 = tt.expand_dims %166 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %171 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %172 = arith.muli %171, %170 : tensor<128x1xi32, #blocked2> + %173 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %174 = tt.addptr %173, %172 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %175 = tt.expand_dims %169 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %176 = tt.broadcast %174 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %177 = tt.broadcast %175 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %178 = tt.addptr %176, %177 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %179 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %180 = arith.cmpi slt, %170, %179 : tensor<128x1xi32, #blocked2> + %181 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %182 = arith.cmpi slt, %175, %181 : tensor<1x256xi32, #blocked2> + %183 = tt.broadcast %180 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %184 = tt.broadcast %182 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %185 = arith.andi %183, %184 : tensor<128x256xi1, #blocked2> + %186 = arith.truncf %163#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %187 = ttg.convert_layout %186 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %178, %187, %185 : tensor<128x256x!tt.ptr, #blocked2> + } + scf.yield %109, %112#4, %121#0, %112#0, %112#1, %112#2, %112#3, %124, %115, %111, %arg21, %109, %arg23, %160, %arg25, %112#0, %arg27, %112#1 : i32, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 + } + %103 = ttng.warp_group_dot_wait %102#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> + %104 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %18 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %19 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + tt.return + } +} + diff --git a/new3.mlir b/new3.mlir new file mode 100644 index 000000000000..3b5823443c13 --- /dev/null +++ b/new3.mlir @@ -0,0 +1,291 @@ 
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %6 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + %11 = arith.addi %arg5, %c63_i32 : i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.subi %13, %0 : i32 + %18 = arith.ceildivsi %17, %c132_i32 : i32 + %19 = arith.maxsi %12, %c1_i32 : i32 + %20 = arith.muli %18, %19 : i32 
+ %21 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %23 = arith.cmpi sgt, %20, %c0_i32 : i32 + %24 = arith.divsi %0, %14 : i32 + %25 = arith.muli %24, %c8_i32 : i32 + %26 = arith.subi %8, %25 : i32 + %27 = arith.minsi %26, %c8_i32 : i32 + %28 = arith.remsi %0, %27 : i32 + %29 = arith.addi %25, %28 : i32 + %30 = arith.remsi %0, %14 : i32 + %31 = arith.divsi %30, %27 : i32 + %32 = arith.muli %29, %c128_i32 : i32 + %33 = arith.muli %31, %c256_i32 : i32 + %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %35 = arith.addi %34, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %37 = arith.addi %36, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %38 = arith.cmpi slt, %35, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %39 = arith.select %38, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %40 = arith.cmpi slt, %37, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %41 = arith.select %40, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %43 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %44 = arith.muli %42, %43 : tensor<128x1xi32, #blocked1> + %45 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %46 = tt.broadcast %44 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %47 = tt.broadcast %45 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %48 = arith.addi %46, %47 : tensor<128x64xi32, #blocked1> + %49 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %50 = tt.addptr %49, %48 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %51 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %52 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %53 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %54 = arith.muli %52, %53 : tensor<1x256xi32, #blocked> + %55 = tt.broadcast %51 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %56 = tt.broadcast %54 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %57 = arith.addi %55, %56 : tensor<64x256xi32, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %59 = tt.addptr %58, %57 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %60 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %61 = arith.cmpi slt, %45, %60 : tensor<1x64xi32, #blocked1> + %62 = tt.broadcast %61 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %63 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + 
%64 = tt.splat %23 : i1 -> tensor<128x64xi1, #blocked1> + %65 = arith.andi %64, %62 : tensor<128x64xi1, #blocked1> + %66 = ttg.async_copy_global_to_local %50, %63 mask %65 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %67 = ttg.async_commit_group %66 + %68 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> + %69 = arith.cmpi slt, %51, %68 : tensor<64x1xi32, #blocked> + %70 = tt.broadcast %69 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %71 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %72 = tt.splat %23 : i1 -> tensor<64x256xi1, #blocked> + %73 = arith.andi %72, %70 : tensor<64x256xi1, #blocked> + %74 = ttg.async_copy_global_to_local %59, %71 mask %73 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %75 = ttg.async_commit_group %74 + %76 = arith.cmpi sgt, %20, %c1_i32 : i32 + %77 = arith.remsi %c1_i32, %19 : i32 + %78 = arith.cmpi eq, %77, %c0_i32 : i32 + %79 = arith.cmpi ne, %77, %c0_i32 : i32 + %80 = arith.extui %79 : i1 to i32 + %81:5 = scf.if %78 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %121 = arith.addi %0, %c132_i32 : i32 + %122 = arith.divsi %121, %14 : i32 + %123 = arith.muli %122, %c8_i32 : i32 + %124 = arith.subi %8, %123 : i32 + %125 = arith.minsi %124, %c8_i32 : i32 + %126 = arith.remsi %121, %125 : i32 + %127 = arith.addi %123, %126 : i32 + %128 = arith.remsi %121, %14 : i32 + %129 = arith.divsi %128, %125 : i32 + %130 = arith.muli %127, %c128_i32 : i32 + %131 = arith.muli %129, %c256_i32 : i32 + %132 = tt.splat %130 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %133 = arith.addi %132, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %134 = tt.splat %131 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %135 = arith.addi %134, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %136 = arith.cmpi slt, %133, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %137 = arith.select %136, %133, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %138 = arith.cmpi slt, %135, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %139 = arith.select %138, %135, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %130, %131, %137, %139, %121 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } else { + scf.yield %32, %33, %39, %41, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } + %82 = arith.muli %80, %c64_i32 : i32 + %83 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %84 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %85 = arith.addi %83, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %86 = arith.addi %84, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = 
#blocked}>> + %87 = tt.expand_dims %81#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %88 = arith.muli %87, %43 : tensor<128x1xi32, #blocked1> + %89 = tt.expand_dims %85 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %90 = tt.broadcast %88 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %91 = tt.broadcast %89 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %92 = arith.addi %90, %91 : tensor<128x64xi32, #blocked1> + %93 = tt.addptr %49, %92 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %94 = tt.expand_dims %86 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %95 = tt.expand_dims %81#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %96 = arith.muli %95, %53 : tensor<1x256xi32, #blocked> + %97 = tt.broadcast %94 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %98 = tt.broadcast %96 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %99 = arith.addi %97, %98 : tensor<64x256xi32, #blocked> + %100 = tt.addptr %58, %99 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %101 = arith.subi %arg5, %82 : i32 + %102 = tt.splat %101 : i32 -> tensor<1x64xi32, #blocked1> + %103 = arith.cmpi slt, %45, %102 : tensor<1x64xi32, #blocked1> + %104 = tt.broadcast %103 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %105 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %106 = tt.splat %76 : i1 -> tensor<128x64xi1, #blocked1> + %107 = arith.andi %106, %104 : tensor<128x64xi1, #blocked1> + %108 = ttg.async_copy_global_to_local %93, %105 mask %107 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %109 = ttg.async_commit_group %108 + %110 = tt.splat %101 : i32 -> tensor<64x1xi32, #blocked> + %111 = arith.cmpi slt, %51, %110 : tensor<64x1xi32, #blocked> + %112 = tt.broadcast %111 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %113 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %114 = tt.splat %76 : i1 -> tensor<64x256xi1, #blocked> + %115 = arith.andi %114, %112 : tensor<64x256xi1, #blocked> + %116 = ttg.async_copy_global_to_local %100, %113 mask %115 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %117 = ttg.async_commit_group %116 + %lol = arith.subi %12, %c1_i32 : i32 + %118:16 = scf.for %arg9 = %c0_i32 to %20 step %c1_i32 iter_args( + %arg10 = %81#4, %arg11 = %cst_3, %arg12 = %81#0, %arg13 = %81#1, + %arg14 = %81#2, %arg15 = %81#3, %arg16 = %c1_i32, %arg17 = %c-1_i32, + %arg18 = %80, %arg19 = %c0_i32, %arg21 = %75, %arg22 = %117, %arg23 = %32, %arg24 = %81#0, %arg25 = %33, %arg26 = %81#1) -> (i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { + %121 = arith.subi %20, %c2_i32 : i32 + %122 = arith.cmpi slt, %arg9, %121 : i32 + %rollover = arith.cmpi eq, %arg18, %lol : i32 + %123 = arith.addi %arg18, %c1_i32 : i32 + %126 = arith.select 
%rollover, %c0_i32, %123 : i32 + %125 = arith.cmpi eq, %126, %c0_i32 : i32 + %127:5 = scf.if %125 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %178 = arith.addi %arg10, %c132_i32 : i32 + %179 = arith.divsi %178, %14 : i32 + %180 = arith.muli %179, %c8_i32 : i32 + %181 = arith.subi %8, %180 : i32 + %182 = arith.minsi %181, %c8_i32 : i32 + %183 = arith.remsi %178, %182 : i32 + %184 = arith.addi %180, %183 : i32 + %185 = arith.remsi %178, %14 : i32 + %186 = arith.divsi %185, %182 : i32 + %187 = arith.muli %184, %c128_i32 : i32 + %188 = arith.muli %186, %c256_i32 : i32 + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %190 = arith.addi %189, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %192 = arith.addi %191, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %193 = arith.cmpi slt, %190, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %195 = arith.cmpi slt, %192, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %187, %188, %194, %196, %178 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } else { + scf.yield %arg12, %arg13, %arg14, %arg15, %arg10 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 + } + %128 = arith.addi %arg17, %c1_i32 : i32 + %129 = arith.cmpi slt, %128, %c3_i32 : i32 + %130 = arith.select %129, %128, %c0_i32 : i32 + %131 = arith.cmpi ne, %arg19, %c0_i32 : i32 + %132 = ttg.memdesc_subview %21[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %133 = ttg.async_wait %arg21 {num = 2 : i32} + %134 = ttg.memdesc_subview %22[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %135 = ttng.warp_group_dot %132, %134, %arg11, %131 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %137 = arith.addi %arg16, %c1_i32 : i32 + %138 = arith.cmpi slt, %137, %c3_i32 : i32 + %139 = arith.select %138, %137, %c0_i32 : i32 + %140 = arith.muli %126, %c64_i32 : i32 + %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %143 = arith.addi %141, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %144 = arith.addi 
%142, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %145 = tt.expand_dims %127#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %146 = arith.muli %145, %43 : tensor<128x1xi32, #blocked1> + %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> + %151 = tt.addptr %49, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %153 = tt.expand_dims %127#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %154 = arith.muli %153, %53 : tensor<1x256xi32, #blocked> + %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> + %158 = tt.addptr %58, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %159 = arith.subi %arg5, %140 : i32 + %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> + %161 = arith.cmpi slt, %45, %160 : tensor<1x64xi32, #blocked1> + %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %163 = ttg.memdesc_subview %21[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %164 = tt.splat %122 : i1 -> tensor<128x64xi1, #blocked1> + %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> + %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %167 = ttg.async_commit_group %166 + %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> + %169 = arith.cmpi slt, %51, %168 : tensor<64x1xi32, #blocked> + %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %171 = ttg.memdesc_subview %22[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %172 = tt.splat %122 : i1 -> tensor<64x256xi1, #blocked> + %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> + %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %175 = ttg.async_commit_group %174 + %176 = arith.subi %19, %c1_i32 : i32 + %177 = arith.cmpi eq, %arg19, %176 : i32 + scf.if %177 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %179 = tt.splat %arg23 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %180 = arith.addi %179, %1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %181 = tt.splat %arg25 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %182 = arith.addi %181, %2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %183 = tt.expand_dims %180 {axis = 1 : i32} : 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %184 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %185 = arith.muli %184, %183 : tensor<128x1xi32, #blocked2> + %186 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %187 = tt.addptr %186, %185 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %188 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %189 = tt.broadcast %187 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %190 = tt.broadcast %188 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %191 = tt.addptr %189, %190 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %192 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %193 = arith.cmpi slt, %183, %192 : tensor<128x1xi32, #blocked2> + %194 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %195 = arith.cmpi slt, %188, %194 : tensor<1x256xi32, #blocked2> + %196 = tt.broadcast %193 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %197 = tt.broadcast %195 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %198 = arith.andi %196, %197 : tensor<128x256xi1, #blocked2> + %199 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %200 = ttg.convert_layout %199 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %191, %200, %198 : tensor<128x256x!tt.ptr, #blocked2> + } + scf.yield %127#4, %136#0, %127#0, %127#1, + %127#2, %127#3, %139, %130, + %126, %arg18, %arg22, %175, %arg24, %127#0, %arg26, %127#1 : i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 + } + %119 = ttng.warp_group_dot_wait %118#1 {pendings = 0 : i32} : tensor<128x256xf32, #mma> + %120 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %21 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + tt.return + } +} + diff --git a/orig.mlir b/orig.mlir index 9c63eb7a9f7d..0af6b4b63b38 100644 --- a/orig.mlir +++ b/orig.mlir @@ -1,13 +1,13 @@ #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) +#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0) #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : 
i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { + tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0)) attributes {noinline = false} { %c2_i32 = arith.constant 2 : i32 loc(#loc1) %c3_i32 = arith.constant 3 : i32 loc(#loc1) %false = arith.constant false loc(#loc1) @@ -28,12 +28,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %c63_i32 = arith.constant 63 : i32 loc(#loc1) %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc80) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc81) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc82) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc83) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc84) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc85) + %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc78) + %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc79) + %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc80) + %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc81) + %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc82) + %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc83) %7 = arith.muli %2, %4 : i32 loc(#loc8) %8 = arith.divsi %7, %c132_i32 : i32 loc(#loc9) %9 = arith.remsi %7, %c132_i32 : i32 loc(#loc10) @@ -52,335 +52,328 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc17) %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc17) %19 = arith.muli %6, 
%11 : i32 loc(#loc18) - %20 = arith.subi %6, %c1_i32 : i32 loc(#loc19) - %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc20) - %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc21) - %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc22) - %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc23) - %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc24) - %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc25) - %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc26) - %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc27) - %29 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc28) - %30 = arith.divsi %0, %14 : i32 loc(#loc29) - %31 = arith.muli %30, %c8_i32 : i32 loc(#loc30) - %32 = arith.subi %2, %31 : i32 loc(#loc31) - %33 = arith.minsi %32, %c8_i32 : i32 loc(#loc32) - %34 = arith.remsi %0, %33 : i32 loc(#loc33) - %35 = arith.addi %31, %34 : i32 loc(#loc34) - %36 = arith.remsi %0, %14 : i32 loc(#loc35) - %37 = arith.divsi %36, %33 : i32 loc(#loc36) - %38 = arith.muli %35, %c128_i32 : i32 loc(#loc37) - %39 = arith.muli %37, %c256_i32 : i32 loc(#loc38) - %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> loc(#loc46) - %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - 
%58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> loc(#loc48) - %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) + %20 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc19) + %21 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc20) + %22 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc21) + %23 = arith.subi %6, %c1_i32 : i32 loc(#loc22) + %24 = arith.divsi %0, %14 : i32 loc(#loc23) + %25 = arith.muli %24, %c8_i32 : i32 loc(#loc24) + %26 = arith.subi %2, %25 : i32 loc(#loc25) + %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc26) + %28 = arith.remsi %0, %27 : i32 loc(#loc27) + %29 = arith.addi %25, %28 : i32 loc(#loc28) + %30 = arith.remsi %0, %14 : i32 loc(#loc29) + %31 = arith.divsi %30, %27 : i32 loc(#loc30) + %32 = arith.muli %29, %c128_i32 : i32 loc(#loc31) + %33 = arith.muli %31, %c256_i32 : i32 loc(#loc32) + %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %35 = arith.addi %34, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %37 = arith.addi %36, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %38 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) + %39 = arith.cmpi slt, %35, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) + %40 = arith.select %39, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) + %41 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) + %42 = arith.cmpi slt, %37, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) + %43 = arith.select %42, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) + %44 = tt.expand_dims %40 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) + %45 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc40) + %46 = arith.muli %44, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) + %47 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc41) + %48 = tt.broadcast %46 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %49 = tt.broadcast %47 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %50 = arith.addi %48, %49 : tensor<128x64xi32, #blocked1> loc(#loc42) + %51 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc43) + %52 = tt.addptr %51, %50 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) + %53 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) + %54 = tt.expand_dims 
%43 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) + %55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc46) + %56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked> loc(#loc46) + %57 = tt.broadcast %53 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %59 = arith.addi %57, %58 : tensor<64x256xi32, #blocked> loc(#loc47) + %60 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc48) + %61 = tt.addptr %60, %59 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) - %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc28) - %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %69 = ttg.async_commit_group %68 loc(#loc26) + %63 = arith.cmpi slt, %47, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) + %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) + %65 = ttg.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %66 = tt.splat %22 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) + %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc21) + %68 = ttg.async_copy_global_to_local %52, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %69 = ttg.async_commit_group %68 loc(#loc19) %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> loc(#loc50) - %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc28) - %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %77 = ttg.async_commit_group %76 loc(#loc27) - %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc28) - %79 = arith.cmpi ne, %20, %c0_i32 : i32 loc(#loc86) + %71 = arith.cmpi slt, %53, %70 : tensor<64x1xi32, #blocked> loc(#loc50) + %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) + %73 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %74 = tt.splat %22 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) + %75 = 
arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc21) + %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %77 = ttg.async_commit_group %76 loc(#loc20) + %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc21) + %79 = arith.cmpi ne, %23, %c0_i32 : i32 loc(#loc84) %80 = arith.extui %79 : i1 to i32 loc(#loc51) %81 = arith.cmpi eq, %80, %c0_i32 : i32 loc(#loc53) %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { %122 = arith.addi %0, %c132_i32 : i32 loc(#loc55) - %123 = arith.divsi %122, %14 : i32 loc(#loc29) - %124 = arith.muli %123, %c8_i32 : i32 loc(#loc30) - %125 = arith.subi %2, %124 : i32 loc(#loc31) - %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc32) - %127 = arith.remsi %122, %126 : i32 loc(#loc33) - %128 = arith.addi %124, %127 : i32 loc(#loc34) - %129 = arith.remsi %122, %14 : i32 loc(#loc35) - %130 = arith.divsi %129, %126 : i32 loc(#loc36) - %131 = arith.muli %128, %c128_i32 : i32 loc(#loc37) - %132 = arith.muli %130, %c256_i32 : i32 loc(#loc38) - %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + %123 = arith.divsi %122, %14 : i32 loc(#loc23) + %124 = arith.muli %123, %c8_i32 : i32 loc(#loc24) + %125 = arith.subi %2, %124 : i32 loc(#loc25) + %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc26) + %127 = arith.remsi %122, %126 : i32 loc(#loc27) + %128 = arith.addi %124, %127 : i32 loc(#loc28) + %129 = arith.remsi %122, %14 : i32 loc(#loc29) + %130 = arith.divsi %129, %126 : i32 loc(#loc30) + %131 = arith.muli %128, %c128_i32 : i32 loc(#loc31) + %132 = arith.muli %130, %c256_i32 : i32 loc(#loc32) + %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %137 = arith.cmpi slt, %134, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> 
loc(#loc35) + %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) + %139 = arith.cmpi slt, %136, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) + %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) + scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) } else { - scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) + scf.yield %0, %29, %31, %40, %43 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) } loc(#loc54) %83 = arith.muli %80, %c64_i32 : i32 loc(#loc56) %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) - %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc46) - %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) - %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc48) - %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) - %102 = arith.subi %arg5, %83 : i32 loc(#loc60) + %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) + %89 = arith.muli %88, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) + %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> 
loc(#loc41) + %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc42) + %94 = tt.addptr %51, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) + %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) + %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) + %97 = arith.muli %96, %55 : tensor<1x256xi32, #blocked> loc(#loc46) + %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc47) + %101 = tt.addptr %60, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) + %102 = arith.subi %arg5, %83 : i32 loc(#loc58) %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) - %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc28) - %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %110 = ttg.async_commit_group %109 loc(#loc26) + %104 = arith.cmpi slt, %47, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) + %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) + %106 = ttg.memdesc_subview %20[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) + %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc21) + %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %110 = ttg.async_commit_group %109 loc(#loc19) %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> loc(#loc50) - %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc28) - %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %118 = ttg.async_commit_group %117 loc(#loc27) - %119:18 = scf.for %arg9 = %c0_i32 to %19 
step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> ( - i32, i32, i32, i32, - tensor<128x256xf32, #mma>, - tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, - tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, - i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { - %122 = arith.subi %19, %c2_i32 : i32 loc(#loc28) - %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc28) - %124 = arith.cmpi eq, %arg10, %20 : i32 loc(#loc52) - %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc61) + %112 = arith.cmpi slt, %53, %111 : tensor<64x1xi32, #blocked> loc(#loc50) + %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) + %114 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) + %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc21) + %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %118 = ttg.async_commit_group %117 loc(#loc20) + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %29, %arg25 = %82#1, %arg26 = %31, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %122 = arith.subi %19, %c2_i32 : i32 loc(#loc21) + %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc21) + %124 = arith.cmpi eq, %arg10, %23 : i32 loc(#loc52) + %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc59) %126 = arith.select %124, %c0_i32, %125 : i32 loc(#loc51) %127 = arith.cmpi eq, %126, %c0_i32 : i32 loc(#loc53) %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { %178 = arith.addi %arg11, %c132_i32 : i32 loc(#loc55) - %179 = arith.divsi %178, %14 : i32 loc(#loc29) - %180 = arith.muli %179, %c8_i32 : i32 loc(#loc30) - %181 = arith.subi %2, %180 : i32 loc(#loc31) - %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc32) - %183 = arith.remsi %178, %182 : i32 loc(#loc33) - %184 = arith.addi %180, %183 : i32 loc(#loc34) - %185 = arith.remsi %178, %14 : i32 loc(#loc35) - %186 = arith.divsi %185, %182 : i32 loc(#loc36) - %187 = arith.muli %184, %c128_i32 : i32 loc(#loc37) - %188 = arith.muli %186, %c256_i32 : i32 loc(#loc38) - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc39) - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc40) - %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 
0, parent = #blocked}>> loc(#loc40) - %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc41) - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc42) - %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc43) - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) - scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc44) + %179 = arith.divsi %178, %14 : i32 loc(#loc23) + %180 = arith.muli %179, %c8_i32 : i32 loc(#loc24) + %181 = arith.subi %2, %180 : i32 loc(#loc25) + %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc26) + %183 = arith.remsi %178, %182 : i32 loc(#loc27) + %184 = arith.addi %180, %183 : i32 loc(#loc28) + %185 = arith.remsi %178, %14 : i32 loc(#loc29) + %186 = arith.divsi %185, %182 : i32 loc(#loc30) + %187 = arith.muli %184, %c128_i32 : i32 loc(#loc31) + %188 = arith.muli %186, %c256_i32 : i32 loc(#loc32) + %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) + %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) + %193 = arith.cmpi slt, %190, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) + %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) + %195 = arith.cmpi slt, %192, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) + %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) + scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) } else { scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) } loc(#loc54) - %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc28) - %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc28) - %131 = arith.select %130, %129, %c0_i32 : i32 loc(#loc28) - %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc26) - %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> 
-> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc62) - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) - %137 = arith.addi %arg18, %c1_i32 : i32 loc(#loc28) - %138 = arith.cmpi slt, %137, %c3_i32 : i32 loc(#loc28) - %139 = arith.select %138, %137, %c0_i32 : i32 loc(#loc28) - %140 = arith.muli %126, %c64_i32 : i32 loc(#loc56) - %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc45) - %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> loc(#loc20) - %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc58) - %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc46) - %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> loc(#loc46) - %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc21) - %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc59) - %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc47) - %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> loc(#loc22) - %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc48) - %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> loc(#loc48) - %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc23) - %159 = arith.subi %arg5, %140 : i32 loc(#loc60) - %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> loc(#loc49) - %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc26) - %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc28) - %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> loc(#loc28) - %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc26) - %167 = 
ttg.async_commit_group %166 loc(#loc26) - %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> loc(#loc50) - %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc27) - %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc28) - %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> loc(#loc28) - %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc27) - %175 = ttg.async_commit_group %174 loc(#loc27) - %176 = arith.cmpi eq, %arg22, %20 : i32 loc(#loc63) - %177 = arith.cmpi ne, %arg22, %20 : i32 loc(#loc87) - scf.if %176 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc62) - %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc65) - %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) - %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc66) - %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc67) - %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) - %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc68) - %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc69) - %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc70) - %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc70) - %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc71) - %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc71) - %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc72) - %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc73) - %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc73) - %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc73) - %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc74) - %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc74) - %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc75) - %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc75) - %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) - %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc76) - %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc76) - %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc77) - %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc78) - tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> 
loc(#loc78) - } loc(#loc64) - scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc28) - } loc(#loc28) - %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc28) - %121 = ttg.async_wait {num = 0 : i32} loc(#loc28) - ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc28) - ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc28) - tt.return loc(#loc79) + %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc21) + %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc21) + %131 = arith.select %130, %129, %c0_i32 : i32 loc(#loc21) + %132 = ttg.memdesc_subview %20[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc19) + %134 = ttg.memdesc_subview %21[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc60) + %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc60) + %137 = arith.cmpi ne, %arg22, %23 : i32 loc(#loc85) + %138 = arith.addi %arg18, %c1_i32 : i32 loc(#loc21) + %139 = arith.cmpi slt, %138, %c3_i32 : i32 loc(#loc21) + %140 = arith.select %139, %138, %c0_i32 : i32 loc(#loc21) + %141 = arith.muli %126, %c64_i32 : i32 loc(#loc56) + %142 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %143 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %144 = arith.addi %142, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) + %145 = arith.addi %143, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) + %146 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) + %147 = arith.muli %146, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) + %148 = tt.expand_dims %144 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc41) + %149 = tt.broadcast %147 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %150 = tt.broadcast %148 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) + %151 = arith.addi %149, %150 : tensor<128x64xi32, #blocked1> loc(#loc42) + %152 = tt.addptr %51, %151 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) + %153 = tt.expand_dims %145 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) + %154 = tt.expand_dims %128#4 {axis = 0 : i32} : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) + %155 = arith.muli %154, %55 : tensor<1x256xi32, #blocked> loc(#loc46) + %156 = tt.broadcast %153 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %157 = tt.broadcast %155 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) + %158 = arith.addi %156, %157 : tensor<64x256xi32, #blocked> loc(#loc47) + %159 = tt.addptr %60, %158 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) + %160 = arith.subi %arg5, %141 : i32 loc(#loc58) + %161 = tt.splat %160 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) + %162 = arith.cmpi slt, %47, %161 : tensor<1x64xi32, #blocked1> loc(#loc49) + %163 = tt.broadcast %162 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) + %164 = ttg.memdesc_subview %20[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %165 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) + %166 = arith.andi %165, %163 : tensor<128x64xi1, #blocked1> loc(#loc21) + %167 = ttg.async_copy_global_to_local %152, %164 mask %166 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) + %168 = ttg.async_commit_group %167 loc(#loc19) + %169 = tt.splat %160 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) + %170 = arith.cmpi slt, %53, %169 : tensor<64x1xi32, #blocked> loc(#loc50) + %171 = tt.broadcast %170 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) + %172 = ttg.memdesc_subview %21[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %173 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) + %174 = arith.andi %173, %171 : tensor<64x256xi1, #blocked> loc(#loc21) + %175 = ttg.async_copy_global_to_local %159, %172 mask %174 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) + %176 = ttg.async_commit_group %175 loc(#loc20) + %177 = arith.cmpi eq, %arg22, %23 : i32 loc(#loc61) + scf.if %177 { + %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc60) + %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc63) + %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc64) + %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc64) + %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc65) + %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc66) + %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc66) + %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc67) + %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc68) + %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc68) + %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc69) + %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc69) + %190 = tt.expand_dims 
%184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc70) + %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc71) + %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc71) + %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc71) + %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc72) + %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc72) + %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc73) + %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc73) + %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc74) + %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc74) + %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc74) + %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc75) + %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc76) + tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc76) + } loc(#loc62) + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %137, %140, %131, %arg21, %176, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc21) + } loc(#loc21) + %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc21) + %121 = ttg.async_wait {num = 0 : i32} loc(#loc21) + ttg.local_dealloc %20 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc21) + ttg.local_dealloc %21 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc21) + tt.return loc(#loc77) } loc(#loc) } loc(#loc) #loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) +#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:30) #loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) +#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) #loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":172:32) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:31) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:19) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:7) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:24) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:35) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":181:38) -#loc16 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:27) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:32) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:38) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:45) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:26) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:75) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:49) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:49) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:20) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:22) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:34) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:43) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:56) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:45) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:35) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:31) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:52) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":200:30) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:32) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:41) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:53) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:34) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:57) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:56) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:60) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:44) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:28) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:17) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:11) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:23) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:22) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:37) -#loc58 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:64) -#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:33) -#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":212:64) -#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:49) -#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:35) -#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:17) -#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":216:11) -#loc65 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:30) -#loc66 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:45) -#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) -#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) -#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:49) -#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:41) -#loc71 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:29) -#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:80) -#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:60) -#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) -#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:66) -#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:47) -#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":224:35) -#loc78 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:29) -#loc79 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":190:4) -#loc80 = loc(callsite(#loc3 at #loc4)) -#loc81 = loc(callsite(#loc5 at #loc4)) -#loc82 = loc(callsite(#loc3 at #loc6)) -#loc83 = loc(callsite(#loc5 at #loc6)) -#loc84 = loc(callsite(#loc3 at #loc7)) -#loc85 = loc(callsite(#loc5 at #loc7)) -#loc86 = loc(fused[#loc51, #loc52]) -#loc87 = loc(fused[#loc64, #loc63]) +#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:27) +#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:25) +#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:28) +#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:32) +#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:31) +#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:19) +#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:7) +#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:24) +#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:35) +#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:38) +#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) +#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:27) +#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:32) +#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) +#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:20) +#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:22) +#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:38) +#loc23 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:34) +#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:37) +#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:43) +#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:56) +#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:45) +#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:35) +#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:31) +#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:52) +#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) +#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:30) +#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) +#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:32) +#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) +#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) +#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":206:41) +#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":206:53) +#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:34) +#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:45) +#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) +#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:57) +#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) +#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:33) +#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:64) +#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:75) +#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:56) +#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:26) +#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) +#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:60) +#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:44) +#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:28) +#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:17) +#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:11) +#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:23) +#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:22) +#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:37) +#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:64) +#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:49) +#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":215:35) +#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:17) +#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:11) +#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) +#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) +#loc65 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:30) +#loc66 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:45) +#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:49) +#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) +#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:29) +#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:80) +#loc71 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:60) +#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:41) +#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:66) +#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:47) +#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:35) +#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":226:29) +#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:4) +#loc78 = loc(callsite(#loc3 at #loc4)) +#loc79 = loc(callsite(#loc5 at #loc4)) +#loc80 = loc(callsite(#loc3 at #loc6)) +#loc81 = loc(callsite(#loc5 at #loc6)) +#loc82 = loc(callsite(#loc3 at #loc7)) +#loc83 = loc(callsite(#loc5 at #loc7)) +#loc84 = loc(fused[#loc51, #loc52]) +#loc85 = loc(fused[#loc60, #loc61]) diff --git a/orig2.mlir b/orig2.mlir index 63cc3d385e0b..69b0e81760e4 100644 --- a/orig2.mlir +++ b/orig2.mlir @@ -51,66 +51,66 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> %19 = arith.muli %6, %11 : i32 - %20 = arith.subi %6, %c1_i32 : i32 - %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %29 = arith.cmpi sgt, %19, %c0_i32 : i32 - %30 = arith.divsi %0, %14 : i32 - %31 = arith.muli %30, %c8_i32 : i32 - %32 = arith.subi %2, %31 : i32 - %33 = arith.minsi %32, %c8_i32 : i32 - %34 = arith.remsi %0, %33 : i32 - %35 = arith.addi %31, %34 : i32 - %36 = arith.remsi %0, %14 : i32 - %37 = arith.divsi %36, %33 : i32 - %38 = arith.muli %35, %c128_i32 : i32 - %39 = arith.muli %37, %c256_i32 : i32 - %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %42 = tt.splat %39 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %46 = arith.select 
%45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> - %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> - %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> - %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> - %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %20 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %21 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %22 = arith.cmpi sgt, %19, %c0_i32 : i32 + %23 = arith.subi %6, %c1_i32 : i32 + %24 = arith.divsi %0, %14 : i32 + %25 = arith.muli %24, %c8_i32 : i32 + %26 = arith.subi %2, %25 : i32 + %27 = arith.minsi %26, %c8_i32 : i32 + %28 = arith.remsi %0, %27 : i32 + %29 = arith.addi %25, %28 : i32 + %30 = arith.remsi %0, %14 : i32 + %31 = arith.divsi %30, %27 : i32 + %32 = arith.muli %29, %c128_i32 : i32 + %33 = arith.muli %31, %c256_i32 : i32 + %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %35 = arith.addi %34, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %37 = arith.addi %36, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %38 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %39 = arith.cmpi slt, %35, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %40 = arith.select %39, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %41 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %42 = arith.cmpi slt, %37, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %43 = arith.select %42, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %44 = tt.expand_dims %40 
{axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %45 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %46 = arith.muli %44, %45 : tensor<128x1xi32, #blocked1> + %47 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %48 = tt.broadcast %46 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %49 = tt.broadcast %47 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %50 = arith.addi %48, %49 : tensor<128x64xi32, #blocked1> + %51 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %52 = tt.addptr %51, %50 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %53 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %54 = tt.expand_dims %43 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked> + %57 = tt.broadcast %53 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %59 = arith.addi %57, %58 : tensor<64x256xi32, #blocked> + %60 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %61 = tt.addptr %60, %59 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> + %63 = arith.cmpi slt, %47, %62 : tensor<1x64xi32, #blocked1> %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> + %65 = ttg.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %66 = tt.splat %22 : i1 -> tensor<128x64xi1, #blocked1> %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> - %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %68 = ttg.async_copy_global_to_local %52, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> %69 = ttg.async_commit_group %68 %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> + %71 = arith.cmpi slt, %53, %70 : tensor<64x1xi32, #blocked> %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> + %73 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %74 = tt.splat %22 : i1 -> tensor<64x256xi1, #blocked> %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, 
mutable> %77 = ttg.async_commit_group %76 %78 = arith.cmpi sgt, %19, %c1_i32 : i32 - %79 = arith.cmpi ne, %20, %c0_i32 : i32 + %79 = arith.cmpi ne, %23, %c0_i32 : i32 %80 = arith.extui %79 : i1 to i32 %81 = arith.cmpi eq, %80, %c0_i32 : i32 %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { @@ -129,13 +129,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %137 = arith.cmpi slt, %134, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %139 = arith.cmpi slt, %136, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> } else { - scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %0, %29, %31, %40, %43 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> } %83 = arith.muli %80, %c64_i32 : i32 %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> @@ -143,40 +143,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> + %89 = arith.muli %88, %45 : tensor<128x1xi32, #blocked1> %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> - %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %94 = tt.addptr %51, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %96 = tt.expand_dims %82#4 {axis = 0 : i32} : 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> + %97 = arith.muli %96, %55 : tensor<1x256xi32, #blocked> %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> - %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %101 = tt.addptr %60, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %102 = arith.subi %arg5, %83 : i32 %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> - %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> + %104 = arith.cmpi slt, %47, %103 : tensor<1x64xi32, #blocked1> %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %106 = ttg.memdesc_subview %20[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> %110 = ttg.async_commit_group %109 %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> - %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> + %112 = arith.cmpi slt, %53, %111 : tensor<64x1xi32, #blocked> %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %114 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> %118 = ttg.async_commit_group %117 - %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { + %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %29, %arg25 = %82#1, %arg26 = %31, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { %122 = arith.subi %19, %c2_i32 : i32 %123 = arith.cmpi slt, %arg9, %122 : i32 - %124 = arith.cmpi eq, %arg10, %20 : i32 + %124 = arith.cmpi eq, %arg10, %23 : i32 %125 = arith.addi %arg10, %c1_i32 : i32 %126 = arith.select %124, %c0_i32, %125 : i32 %127 = arith.cmpi eq, %126, %c0_i32 : i32 @@ -196,9 +196,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %193 = arith.cmpi slt, %190, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %195 = arith.cmpi slt, %192, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> } else { @@ -207,53 +207,53 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %129 = arith.addi %arg19, %c1_i32 : i32 %130 = arith.cmpi slt, %129, %c3_i32 : i32 %131 = arith.select %130, %129, %c0_i32 : i32 - %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %132 = ttg.memdesc_subview %20[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> %133 = ttg.async_wait %arg20 {num = 2 : i32} - %134 = ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %134 = ttg.memdesc_subview %21[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %137 = arith.addi %arg18, %c1_i32 : i32 - %138 = arith.cmpi slt, %137, %c3_i32 : i32 - %139 = arith.select %138, %137, %c0_i32 : i32 - %140 = arith.muli %126, %c64_i32 : i32 - %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %142 = tt.splat %140 : i32 -> tensor<64xi32, 
#ttg.slice<{dim = 1, parent = #blocked}>> - %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> - %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> - %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> - %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> - %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %159 = arith.subi %arg5, %140 : i32 - %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> - %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> - %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> - %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> - %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %167 = ttg.async_commit_group %166 - %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> - %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> - %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> - %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> - %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %175 = ttg.async_commit_group %174 - %176 = arith.cmpi eq, %arg22, %20 : i32 - %177 = arith.cmpi ne, %arg22, %20 : i32 - scf.if %176 { + %137 = arith.cmpi ne, %arg22, %23 : i32 + %138 = arith.addi %arg18, %c1_i32 : i32 + %139 = arith.cmpi slt, %138, %c3_i32 : i32 + %140 = arith.select %139, %138, %c0_i32 : i32 + %141 = arith.muli %126, %c64_i32 : i32 + %142 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %143 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %144 = arith.addi %142, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = 
#blocked1}>> + %145 = arith.addi %143, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %146 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %147 = arith.muli %146, %45 : tensor<128x1xi32, #blocked1> + %148 = tt.expand_dims %144 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %149 = tt.broadcast %147 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %150 = tt.broadcast %148 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %151 = arith.addi %149, %150 : tensor<128x64xi32, #blocked1> + %152 = tt.addptr %51, %151 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %153 = tt.expand_dims %145 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %154 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %155 = arith.muli %154, %55 : tensor<1x256xi32, #blocked> + %156 = tt.broadcast %153 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %157 = tt.broadcast %155 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %158 = arith.addi %156, %157 : tensor<64x256xi32, #blocked> + %159 = tt.addptr %60, %158 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %160 = arith.subi %arg5, %141 : i32 + %161 = tt.splat %160 : i32 -> tensor<1x64xi32, #blocked1> + %162 = arith.cmpi slt, %47, %161 : tensor<1x64xi32, #blocked1> + %163 = tt.broadcast %162 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %164 = ttg.memdesc_subview %20[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %165 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> + %166 = arith.andi %165, %163 : tensor<128x64xi1, #blocked1> + %167 = ttg.async_copy_global_to_local %152, %164 mask %166 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %168 = ttg.async_commit_group %167 + %169 = tt.splat %160 : i32 -> tensor<64x1xi32, #blocked> + %170 = arith.cmpi slt, %53, %169 : tensor<64x1xi32, #blocked> + %171 = tt.broadcast %170 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %172 = ttg.memdesc_subview %21[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %173 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> + %174 = arith.andi %173, %171 : tensor<64x256xi1, #blocked> + %175 = ttg.async_copy_global_to_local %159, %172 mask %174 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %176 = ttg.async_commit_group %175 + %177 = arith.cmpi eq, %arg22, %23 : i32 + scf.if %177 { %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> %179 = arith.muli %arg24, %c128_i32 : i32 %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> @@ -281,12 +281,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> } - scf.yield %126, %128#0, %128#1, %128#2, 
%136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 + scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %137, %140, %131, %arg21, %176, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 } %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> %121 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + ttg.local_dealloc %20 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %21 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> tt.return } } diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 5499e035915b..32a7848a3182 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -263,6 +263,22 @@ def matmul_persistent_fused(a, b): num_stages=configs[dtype]["num_stages"], # num_warps=configs[dtype]["num_warps"], # ) + #kernel = matmul_kernel_persistent_fused.warmup( + # a, b, c, # + # M, N, K, # + # a.stride(0), a.stride(1), # + # b.stride(0), b.stride(1), # + # c.stride(0), c.stride(1), # + # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + # BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + # BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + # GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + # NUM_SMS=NUM_SMS, # + # num_stages=configs[dtype]["num_stages"], # + # num_warps=configs[dtype]["num_warps"], # + # grid=grid + #) + #print(kernel.asm["ttgir"]) return c @@ -333,185 +349,182 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c2_i32 = arith.constant 2 : i32 %c3_i32 = arith.constant 3 : i32 - %false = arith.constant false + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + 
%c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %c256_i32 = arith.constant 256 : i32 %c128_i32 = arith.constant 128 : i32 - %c0_i32 = arith.constant 0 : i32 %c8_i32 = arith.constant 8 : i32 - %c-1_i32 = arith.constant -1 : i32 - %c1_i32 = arith.constant 1 : i32 - %c132_i32 = arith.constant 132 : i32 %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> %c64_i32 = arith.constant 64 : i32 + %c132_i32 = arith.constant 132 : i32 %c127_i32 = arith.constant 127 : i32 %c255_i32 = arith.constant 255 : i32 %c63_i32 = arith.constant 63 : i32 %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %2, %4 : i32 - %8 = arith.divsi %7, %c132_i32 : i32 - %9 = arith.remsi %7, %c132_i32 : i32 - %10 = arith.cmpi slt, %0, %9 : i32 - %11 = scf.if %10 -> (i32) { - %122 = arith.addi %8, %c1_i32 : i32 - scf.yield %122 : i32 - } else { - scf.yield %8 : i32 - } - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %14 = arith.muli %4, %c8_i32 : i32 - %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %19 = arith.muli %6, %11 : i32 - %20 = arith.subi %6, %c1_i32 : i32 - %21 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %22 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %23 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %25 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %26 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %27 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %28 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %29 = arith.cmpi sgt, %19, %c0_i32 : i32 - %30 = arith.divsi %0, %14 : i32 - %31 = arith.muli %30, %c8_i32 : i32 - %32 = arith.subi %2, %31 : i32 - %33 = arith.minsi %32, %c8_i32 : i32 - %34 = arith.remsi %0, %33 : i32 - %35 = arith.addi %31, %34 : i32 - %36 = arith.remsi %0, %14 : i32 - %37 = arith.divsi %36, %33 : i32 - %38 = arith.muli %35, %c128_i32 : i32 - %39 = arith.muli %37, %c256_i32 : i32 - %40 = tt.splat %38 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %41 = arith.addi %40, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %42 = tt.splat %39 : i32 -> tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> - %43 = arith.addi %42, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %44 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %45 = arith.cmpi slt, %41, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %46 = arith.select %45, %41, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %47 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %48 = arith.cmpi slt, %43, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %49 = arith.select %48, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %50 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %51 = arith.muli %50, %21 : tensor<128x1xi32, #blocked1> - %52 = tt.broadcast %51 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %53 = tt.broadcast %25 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %54 = arith.addi %52, %53 : tensor<128x64xi32, #blocked1> - %55 = tt.addptr %22, %54 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %56 = tt.expand_dims %49 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %57 = arith.muli %56, %23 : tensor<1x256xi32, #blocked> - %58 = tt.broadcast %26 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %59 = tt.broadcast %57 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %60 = arith.addi %58, %59 : tensor<64x256xi32, #blocked> - %61 = tt.addptr %24, %60 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %63 = arith.cmpi slt, %25, %62 : tensor<1x64xi32, #blocked1> - %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %65 = ttg.memdesc_subview %27[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %66 = tt.splat %29 : i1 -> tensor<128x64xi1, #blocked1> - %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> - %68 = ttg.async_copy_global_to_local %55, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %69 = ttg.async_commit_group %68 - %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %71 = arith.cmpi slt, %26, %70 : tensor<64x1xi32, #blocked> - %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %73 = ttg.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %74 = tt.splat %29 : i1 -> tensor<64x256xi1, #blocked> - %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> - %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %77 = ttg.async_commit_group %76 - %78 = arith.cmpi sgt, %19, %c1_i32 : i32 - %79 = arith.cmpi ne, %20, %c0_i32 : i32 + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim 
= 1, parent = #blocked2}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %6 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = arith.addi %arg3, %c127_i32 : i32 + %8 = arith.divsi %7, %c128_i32 : i32 + %9 = arith.addi %arg4, %c255_i32 : i32 + %10 = arith.divsi %9, %c256_i32 : i32 + %11 = arith.addi %arg5, %c63_i32 : i32 + %12 = arith.divsi %11, %c64_i32 : i32 + %13 = arith.muli %8, %10 : i32 + %14 = arith.muli %10, %c8_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.subi %13, %0 : i32 + %18 = arith.ceildivsi %17, %c132_i32 : i32 + %19 = arith.maxsi %12, %c1_i32 : i32 + %20 = arith.muli %18, %19 : i32 + %21 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %23 = arith.cmpi sgt, %20, %c0_i32 : i32 + %24 = arith.divsi %0, %14 : i32 + %25 = arith.muli %24, %c8_i32 : i32 + %26 = arith.subi %8, %25 : i32 + %27 = arith.minsi %26, %c8_i32 : i32 + %28 = arith.remsi %0, %27 : i32 + %29 = arith.addi %25, %28 : i32 + %30 = arith.remsi %0, %14 : i32 + %31 = arith.divsi %30, %27 : i32 + %32 = arith.muli %29, %c128_i32 : i32 + %33 = arith.muli %31, %c256_i32 : i32 + %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %35 = arith.addi %34, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %37 = arith.addi %36, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %38 = arith.cmpi slt, %35, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %39 = arith.select %38, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %40 = arith.cmpi slt, %37, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %41 = arith.select %40, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %43 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %44 = arith.muli %42, %43 : tensor<128x1xi32, #blocked1> + %45 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %46 = tt.broadcast %44 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %47 = tt.broadcast %45 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %48 = arith.addi %46, %47 : tensor<128x64xi32, #blocked1> + %49 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> 
+ %50 = tt.addptr %49, %48 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %51 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %52 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %53 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> + %54 = arith.muli %52, %53 : tensor<1x256xi32, #blocked> + %55 = tt.broadcast %51 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %56 = tt.broadcast %54 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %57 = arith.addi %55, %56 : tensor<64x256xi32, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %59 = tt.addptr %58, %57 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %60 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %61 = arith.cmpi slt, %45, %60 : tensor<1x64xi32, #blocked1> + %62 = tt.broadcast %61 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %63 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %64 = tt.splat %23 : i1 -> tensor<128x64xi1, #blocked1> + %65 = arith.andi %64, %62 : tensor<128x64xi1, #blocked1> + %66 = ttg.async_copy_global_to_local %50, %63 mask %65 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %67 = ttg.async_commit_group %66 + %68 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> + %69 = arith.cmpi slt, %51, %68 : tensor<64x1xi32, #blocked> + %70 = tt.broadcast %69 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %71 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %72 = tt.splat %23 : i1 -> tensor<64x256xi1, #blocked> + %73 = arith.andi %72, %70 : tensor<64x256xi1, #blocked> + %74 = ttg.async_copy_global_to_local %59, %71 mask %73 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %75 = ttg.async_commit_group %74 + %76 = arith.cmpi sgt, %20, %c1_i32 : i32 + %77 = arith.remsi %c1_i32, %19 : i32 + %78 = arith.cmpi eq, %77, %c0_i32 : i32 + %79 = arith.cmpi ne, %77, %c0_i32 : i32 %80 = arith.extui %79 : i1 to i32 - %81 = arith.cmpi eq, %80, %c0_i32 : i32 - %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %122 = arith.addi %0, %c132_i32 : i32 - %123 = arith.divsi %122, %14 : i32 - %124 = arith.muli %123, %c8_i32 : i32 - %125 = arith.subi %2, %124 : i32 - %126 = arith.minsi %125, %c8_i32 : i32 - %127 = arith.remsi %122, %126 : i32 - %128 = arith.addi %124, %127 : i32 - %129 = arith.remsi %122, %14 : i32 - %130 = arith.divsi %129, %126 : i32 - %131 = arith.muli %128, %c128_i32 : i32 - %132 = arith.muli %130, %c256_i32 : i32 - %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %137 = arith.cmpi slt, %134, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %138 = arith.select %137, %134, %cst_0 
{tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %139 = arith.cmpi slt, %136, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %81:5 = scf.if %78 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %121 = arith.addi %0, %c132_i32 : i32 + %122 = arith.divsi %121, %14 : i32 + %123 = arith.muli %122, %c8_i32 : i32 + %124 = arith.subi %8, %123 : i32 + %125 = arith.minsi %124, %c8_i32 : i32 + %126 = arith.remsi %121, %125 : i32 + %127 = arith.addi %123, %126 : i32 + %128 = arith.remsi %121, %14 : i32 + %129 = arith.divsi %128, %125 : i32 + %130 = arith.muli %127, %c128_i32 : i32 + %131 = arith.muli %129, %c256_i32 : i32 + %132 = tt.splat %130 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %133 = arith.addi %132, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %134 = tt.splat %131 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %135 = arith.addi %134, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %136 = arith.cmpi slt, %133, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %137 = arith.select %136, %133, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %138 = arith.cmpi slt, %135, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %139 = arith.select %138, %135, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %130, %131, %137, %139, %121 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 } else { - scf.yield %0, %35, %37, %46, %49 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %32, %33, %39, %41, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 } - %83 = arith.muli %80, %c64_i32 : i32 - %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %89 = arith.muli %88, %21 : tensor<128x1xi32, #blocked1> - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, 
#ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> - %94 = tt.addptr %22, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %97 = arith.muli %96, %23 : tensor<1x256xi32, #blocked> - %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> - %101 = tt.addptr %24, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %102 = arith.subi %arg5, %83 : i32 - %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> - %104 = arith.cmpi slt, %25, %103 : tensor<1x64xi32, #blocked1> - %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %106 = ttg.memdesc_subview %27[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> - %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> - %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %110 = ttg.async_commit_group %109 - %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> - %112 = arith.cmpi slt, %26, %111 : tensor<64x1xi32, #blocked> - %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %114 = ttg.memdesc_subview %28[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> - %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> - %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %118 = ttg.async_commit_group %117 - %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %35, %arg25 = %82#1, %arg26 = %37, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { - %122 = arith.subi %19, %c2_i32 : i32 - %123 = arith.cmpi slt, %arg9, %122 : i32 - %124 = arith.cmpi eq, %arg10, %20 : i32 - %125 = arith.addi %arg10, %c1_i32 : i32 - %126 = arith.select %124, %c0_i32, %125 : i32 - %127 = arith.cmpi eq, %126, %c0_i32 : i32 - %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %178 = arith.addi %arg11, %c132_i32 : i32 + %82 = 
arith.muli %80, %c64_i32 : i32 + %83 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %84 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %85 = arith.addi %83, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %86 = arith.addi %84, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %87 = tt.expand_dims %81#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %88 = arith.muli %87, %43 : tensor<128x1xi32, #blocked1> + %89 = tt.expand_dims %85 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %90 = tt.broadcast %88 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %91 = tt.broadcast %89 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %92 = arith.addi %90, %91 : tensor<128x64xi32, #blocked1> + %93 = tt.addptr %49, %92 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %94 = tt.expand_dims %86 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %95 = tt.expand_dims %81#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %96 = arith.muli %95, %53 : tensor<1x256xi32, #blocked> + %97 = tt.broadcast %94 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %98 = tt.broadcast %96 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> + %99 = arith.addi %97, %98 : tensor<64x256xi32, #blocked> + %100 = tt.addptr %58, %99 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %101 = arith.subi %arg5, %82 : i32 + %102 = tt.splat %101 : i32 -> tensor<1x64xi32, #blocked1> + %103 = arith.cmpi slt, %45, %102 : tensor<1x64xi32, #blocked1> + %104 = tt.broadcast %103 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> + %105 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %106 = tt.splat %76 : i1 -> tensor<128x64xi1, #blocked1> + %107 = arith.andi %106, %104 : tensor<128x64xi1, #blocked1> + %108 = ttg.async_copy_global_to_local %93, %105 mask %107 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> + %109 = ttg.async_commit_group %108 + %110 = tt.splat %101 : i32 -> tensor<64x1xi32, #blocked> + %111 = arith.cmpi slt, %51, %110 : tensor<64x1xi32, #blocked> + %112 = tt.broadcast %111 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> + %113 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %114 = tt.splat %76 : i1 -> tensor<64x256xi1, #blocked> + %115 = arith.andi %114, %112 : tensor<64x256xi1, #blocked> + %116 = ttg.async_copy_global_to_local %100, %113 mask %115 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> + %117 = ttg.async_commit_group %116 + %lol = arith.subi %12, %c1_i32 : i32 + %118:16 = scf.for %arg9 = %c0_i32 to %20 step %c1_i32 iter_args( + %arg10 = %81#4, %arg11 = %cst_3, %arg12 = %81#0, %arg13 = %81#1, + %arg14 = %81#2, %arg15 = %81#3, %arg16 = %c1_i32, %arg17 = %c-1_i32, + %arg18 = %80, %arg19 = %c0_i32, %arg21 = %75, %arg22 = %117, %arg23 = %32, %arg24 = %81#0, %arg25 = %33, %arg26 = %81#1) -> (i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 
1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { + %121 = arith.subi %20, %c2_i32 : i32 + %122 = arith.cmpi slt, %arg9, %121 : i32 + %rollover = arith.cmpi eq, %arg18, %lol : i32 + %123 = arith.addi %arg18, %c1_i32 : i32 + %126 = arith.select %rollover, %c0_i32, %123 : i32 + %125 = arith.cmpi eq, %126, %c0_i32 : i32 + %127:5 = scf.if %125 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { + %178 = arith.addi %arg10, %c132_i32 : i32 %179 = arith.divsi %178, %14 : i32 %180 = arith.muli %179, %c8_i32 : i32 - %181 = arith.subi %2, %180 : i32 + %181 = arith.subi %8, %180 : i32 %182 = arith.minsi %181, %c8_i32 : i32 %183 = arith.remsi %178, %182 : i32 %184 = arith.addi %180, %183 : i32 @@ -520,103 +533,106 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # %187 = arith.muli %184, %c128_i32 : i32 %188 = arith.muli %186, %c256_i32 : i32 %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %190 = arith.addi %189, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %193 = arith.cmpi slt, %190, %44 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %192 = arith.addi %191, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %193 = arith.cmpi slt, %190, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %195 = arith.cmpi slt, %192, %47 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %195 = arith.cmpi slt, %192, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %187, %188, %194, %196, %178 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 } else { - scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.yield %arg12, %arg13, %arg14, %arg15, %arg10 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 } - %129 = arith.addi %arg19, %c1_i32 : i32 - %130 = arith.cmpi slt, %129, %c3_i32 : i32 - %131 = arith.select %130, %129, %c0_i32 : i32 - %132 = ttg.memdesc_subview %27[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %133 = ttg.async_wait %arg20 {num = 2 : i32} - %134 
= ttg.memdesc_subview %28[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> + %128 = arith.addi %arg17, %c1_i32 : i32 + %129 = arith.cmpi slt, %128, %c3_i32 : i32 + %130 = arith.select %129, %128, %c0_i32 : i32 + %131 = arith.cmpi ne, %arg19, %c0_i32 : i32 + %132 = ttg.memdesc_subview %21[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %133 = ttg.async_wait %arg21 {num = 2 : i32} + %134 = ttg.memdesc_subview %22[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %135 = ttng.warp_group_dot %132, %134, %arg11, %131 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %137 = arith.addi %arg18, %c1_i32 : i32 + %137 = arith.addi %arg16, %c1_i32 : i32 %138 = arith.cmpi slt, %137, %c3_i32 : i32 %139 = arith.select %138, %137, %c0_i32 : i32 %140 = arith.muli %126, %c64_i32 : i32 %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %143 = arith.addi %141, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %144 = arith.addi %142, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %145 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %146 = arith.muli %145, %21 : tensor<128x1xi32, #blocked1> + %143 = arith.addi %141, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %144 = arith.addi %142, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %145 = tt.expand_dims %127#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %146 = arith.muli %145, %43 : tensor<128x1xi32, #blocked1> %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> - %151 = tt.addptr %22, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %151 = tt.addptr %49, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %153 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %154 = arith.muli %153, %23 : tensor<1x256xi32, #blocked> + %153 = tt.expand_dims %127#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> 
tensor<1x256xi32, #blocked> + %154 = arith.muli %153, %53 : tensor<1x256xi32, #blocked> %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> - %158 = tt.addptr %24, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %158 = tt.addptr %58, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> %159 = arith.subi %arg5, %140 : i32 %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> - %161 = arith.cmpi slt, %25, %160 : tensor<1x64xi32, #blocked1> + %161 = arith.cmpi slt, %45, %160 : tensor<1x64xi32, #blocked1> %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %163 = ttg.memdesc_subview %27[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %164 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> + %163 = ttg.memdesc_subview %21[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> + %164 = tt.splat %122 : i1 -> tensor<128x64xi1, #blocked1> %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> %167 = ttg.async_commit_group %166 %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> - %169 = arith.cmpi slt, %26, %168 : tensor<64x1xi32, #blocked> + %169 = arith.cmpi slt, %51, %168 : tensor<64x1xi32, #blocked> %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %171 = ttg.memdesc_subview %28[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %172 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> + %171 = ttg.memdesc_subview %22[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + %172 = tt.splat %122 : i1 -> tensor<64x256xi1, #blocked> %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> %175 = ttg.async_commit_group %174 - %176 = arith.cmpi eq, %arg22, %20 : i32 - %177 = arith.cmpi ne, %arg22, %20 : i32 - scf.if %176 { + %176 = arith.subi %19, %c1_i32 : i32 + %177 = arith.cmpi eq, %arg19, %176 : i32 + scf.if %177 { %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %179 = arith.muli %arg24, %c128_i32 : i32 - %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %182 = arith.muli %arg26, %c256_i32 : i32 - %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %187 = arith.muli %186, 
%185 : tensor<128x1xi32, #blocked2> - %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> - %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> - %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> - %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> + %179 = tt.splat %arg23 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %180 = arith.addi %179, %1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %181 = tt.splat %arg25 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %182 = arith.addi %181, %2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %183 = tt.expand_dims %180 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %184 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> + %185 = arith.muli %184, %183 : tensor<128x1xi32, #blocked2> + %186 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> + %187 = tt.addptr %186, %185 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %188 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %189 = tt.broadcast %187 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> + %190 = tt.broadcast %188 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %191 = tt.addptr %189, %190 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %192 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> + %193 = arith.cmpi slt, %183, %192 : tensor<128x1xi32, #blocked2> + %194 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> + %195 = arith.cmpi slt, %188, %194 : tensor<1x256xi32, #blocked2> + %196 = tt.broadcast %193 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %197 = tt.broadcast %195 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %198 = arith.andi %196, %197 : tensor<128x256xi1, #blocked2> + %199 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %200 = ttg.convert_layout %199 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> + tt.store %191, %200, %198 : tensor<128x256x!tt.ptr, #blocked2> } - scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %177, %139, %131, %arg21, %175, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 + scf.yield %127#4, %136#0, %127#0, %127#1, + %127#2, %127#3, %139, %130, + %126, %arg18, %arg22, %175, %arg24, %127#0, %arg26, %127#1 : i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 } - %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> - %121 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %27 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %28 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> + %119 = ttng.warp_group_dot_wait %118#1 {pendings = 0 : i32} : tensor<128x256xf32, #mma> + %120 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %21 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> tt.return } } + + """ file = pathlib.Path("matmul_kernel_persistent.ttgir") @@ -964,7 +980,7 @@ def bench(K, dtype, reps=1000, warmup_reps=10000): # bench_fn(reps, warmup_reps, cublas_matmul, a, b) #if dtype == torch.float16: # bench_fn(reps, warmup_reps, torch_matmul, a, b) - #bench_fn(reps, warmup_reps, matmul, a, b.T) + bench_fn(reps, warmup_reps, matmul, a, b.T) bench_fn(reps, warmup_reps, matmul_persistent_fused, a, b.T) bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) #if supports_tma(): diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index d2ce5bdb4abd..4dfeab265a05 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -252,8 +252,10 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_cse(pm) if capability // 10 >= 8: passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_licm(pm) passes.common.add_canonicalizer(pm) passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.common.add_canonicalizer(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) From bf1fce8aa6d7fadb606145dfe937452f5a9a36c8 Mon Sep 17 00:00:00 2001 From: Mogball Date: Sat, 25 Jan 2025 02:10:17 -0500 Subject: [PATCH 07/32] remove mlir files --- new.mlir | 374 ---------------------------------------------------- new2.mlir | 309 ------------------------------------------- new3.mlir | 291 ---------------------------------------- orig.mlir | 379 ----------------------------------------------------- orig2.mlir | 293 ----------------------------------------- test.mlir | 178 ------------------------- test2.mlir | 128 ------------------ test3.mlir | 177 ------------------------- test4.mlir | 192 --------------------------- test5.mlir | 345 ------------------------------------------------ 10 files changed, 2666 deletions(-) delete mode 100644 new.mlir delete mode 100644 new2.mlir delete mode 100644 new3.mlir delete mode 100644 orig.mlir delete mode 100644 orig2.mlir delete mode 100644 test.mlir delete mode 100644 test2.mlir delete mode 100644 test3.mlir delete mode 100644 test4.mlir delete mode 100644 test5.mlir diff --git a/new.mlir b/new.mlir deleted file mode 100644 index 7852f8582e95..000000000000 --- a/new.mlir +++ /dev/null @@ -1,374 +0,0 @@ -#blocked = 
#ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0) -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":270:0)) attributes {noinline = false} { - %c2_i64 = arith.constant 2 : i64 loc(#loc1) - %c3_i32 = arith.constant 3 : i32 loc(#loc1) - %c-1_i32 = arith.constant -1 : i32 loc(#loc1) - %c1_i64 = arith.constant 1 : i64 loc(#loc1) - %c0_i64 = arith.constant 0 : i64 loc(#loc1) - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) - %c256_i32 = arith.constant 256 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %c132_i32 = arith.constant 132 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c127_i32 = arith.constant 127 : i32 loc(#loc1) - %c255_i32 = arith.constant 255 : i32 loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc59) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc60) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc61) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc62) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc63) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc64) - %7 = 
arith.muli %2, %4 : i32 loc(#loc8) - %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10) - %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc10) - %11 = arith.subi %7, %0 : i32 loc(#loc11) - %12 = arith.ceildivsi %11, %c132_i32 : i32 loc(#loc11) - %13 = arith.extsi %6 : i32 to i64 loc(#loc11) - %14 = arith.maxsi %13, %c1_i64 : i64 loc(#loc11) - %15 = arith.extsi %12 : i32 to i64 loc(#loc11) - %16 = arith.muli %15, %14 : i64 loc(#loc11) - %17 = arith.subi %0, %c132_i32 : i32 loc(#loc11) - %18 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc12) - %19 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc13) - %20 = arith.cmpi sgt, %16, %c0_i64 : i64 loc(#loc11) - %21 = arith.remsi %c0_i64, %14 : i64 loc(#loc11) - %22 = arith.cmpi eq, %21, %c0_i64 : i64 loc(#loc11) - %23 = arith.select %22, %0, %17 : i32 loc(#loc11) - %24:4 = scf.if %22 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %105 = arith.divsi %0, %8 : i32 loc(#loc14) - %106 = arith.muli %105, %c8_i32 : i32 loc(#loc15) - %107 = arith.subi %2, %106 : i32 loc(#loc16) - %108 = arith.minsi %107, %c8_i32 : i32 loc(#loc17) - %109 = arith.remsi %0, %108 : i32 loc(#loc18) - %110 = arith.addi %106, %109 : i32 loc(#loc19) - %111 = arith.remsi %0, %8 : i32 loc(#loc20) - %112 = arith.divsi %111, %108 : i32 loc(#loc21) - %113 = arith.muli %110, %c128_i32 : i32 loc(#loc22) - %114 = arith.muli %112, %c256_i32 : i32 loc(#loc23) - %115 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %116 = tt.splat %113 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %117 = arith.addi %116, %115 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %118 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %119 = tt.splat %114 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %120 = arith.addi %119, %118 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %121 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %122 = arith.cmpi slt, %117, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %123 = arith.select %122, %117, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %124 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %125 = arith.cmpi slt, %120, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %126 = arith.select %125, %120, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %113, %114, %123, %126 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) - } else { - 
scf.yield %c0_i32, %c0_i32, %cst_0, %cst : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc11) - } loc(#loc11) - %25 = tt.expand_dims %24#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %26 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc33) - %27 = arith.muli %25, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) - %28 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %29 = tt.broadcast %27 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %30 = tt.broadcast %28 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %31 = arith.addi %29, %30 : tensor<128x64xi32, #blocked1> loc(#loc35) - %32 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc36) - %33 = tt.addptr %32, %31 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %34 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %35 = tt.expand_dims %24#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %36 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc39) - %37 = arith.muli %35, %36 : tensor<1x256xi32, #blocked> loc(#loc39) - %38 = tt.broadcast %34 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %39 = tt.broadcast %37 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %40 = arith.addi %38, %39 : tensor<64x256xi32, #blocked> loc(#loc40) - %41 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc41) - %42 = tt.addptr %41, %40 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %43 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %44 = arith.cmpi slt, %28, %43 : tensor<1x64xi32, #blocked1> loc(#loc42) - %45 = tt.broadcast %44 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %46 = ttg.memdesc_subview %18[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %47 = tt.splat %20 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) - %48 = arith.andi %47, %45 : tensor<128x64xi1, #blocked1> loc(#loc11) - %49 = ttg.async_copy_global_to_local %33, %46 mask %48 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %50 = ttg.async_commit_group %49 loc(#loc12) - %51 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %52 = arith.cmpi slt, %34, %51 : tensor<64x1xi32, #blocked> loc(#loc43) - %53 = tt.broadcast %52 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %54 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %55 = tt.splat %20 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %56 = arith.andi %55, %53 : tensor<64x256xi1, #blocked> loc(#loc11) - %57 = ttg.async_copy_global_to_local %42, %54 mask %56 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %58 = ttg.async_commit_group %57 
loc(#loc13) - %59 = arith.cmpi sgt, %16, %c1_i64 : i64 loc(#loc11) - %60 = arith.addi %21, %c1_i64 : i64 loc(#loc11) - %61 = arith.remsi %60, %14 : i64 loc(#loc11) - %62 = arith.cmpi eq, %61, %c0_i64 : i64 loc(#loc11) - %63 = arith.cmpi ne, %61, %c0_i64 : i64 loc(#loc11) - %64 = arith.extui %63 : i1 to i32 loc(#loc11) - %65:5 = scf.if %62 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %105 = arith.addi %23, %c132_i32 : i32 loc(#loc11) - %106 = arith.divsi %105, %8 : i32 loc(#loc14) - %107 = arith.muli %106, %c8_i32 : i32 loc(#loc15) - %108 = arith.subi %2, %107 : i32 loc(#loc16) - %109 = arith.minsi %108, %c8_i32 : i32 loc(#loc17) - %110 = arith.remsi %105, %109 : i32 loc(#loc18) - %111 = arith.addi %107, %110 : i32 loc(#loc19) - %112 = arith.remsi %105, %8 : i32 loc(#loc20) - %113 = arith.divsi %112, %109 : i32 loc(#loc21) - %114 = arith.muli %111, %c128_i32 : i32 loc(#loc22) - %115 = arith.muli %113, %c256_i32 : i32 loc(#loc23) - %116 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %117 = tt.splat %114 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %118 = arith.addi %117, %116 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %119 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %120 = tt.splat %115 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %121 = arith.addi %120, %119 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %122 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %123 = arith.cmpi slt, %118, %122 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %124 = arith.select %123, %118, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %125 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %126 = arith.cmpi slt, %121, %125 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %127 = arith.select %126, %121, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %114, %115, %124, %127, %105 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) - } else { - scf.yield %24#0, %24#1, %24#2, %24#3, %23 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) - } loc(#loc11) - %66 = arith.muli %64, %c64_i32 : i32 loc(#loc44) - %67 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %68 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %69 = arith.addi %67, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %70 = arith.addi %68, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %71 = tt.expand_dims %65#2 {axis = 1 : 
i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %72 = arith.muli %71, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) - %73 = tt.expand_dims %69 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %74 = tt.broadcast %72 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %75 = tt.broadcast %73 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %76 = arith.addi %74, %75 : tensor<128x64xi32, #blocked1> loc(#loc35) - %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %78 = tt.expand_dims %70 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %79 = tt.expand_dims %65#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %80 = arith.muli %79, %36 : tensor<1x256xi32, #blocked> loc(#loc39) - %81 = tt.broadcast %78 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %82 = tt.broadcast %80 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %83 = arith.addi %81, %82 : tensor<64x256xi32, #blocked> loc(#loc40) - %84 = tt.addptr %41, %83 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %85 = arith.subi %arg5, %66 : i32 loc(#loc46) - %86 = tt.splat %85 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %87 = arith.cmpi slt, %28, %86 : tensor<1x64xi32, #blocked1> loc(#loc42) - %88 = tt.broadcast %87 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %89 = ttg.memdesc_subview %18[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %90 = tt.splat %59 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc11) - %91 = arith.andi %90, %88 : tensor<128x64xi1, #blocked1> loc(#loc11) - %92 = ttg.async_copy_global_to_local %77, %89 mask %91 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %93 = ttg.async_commit_group %92 loc(#loc12) - %94 = tt.splat %85 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %95 = arith.cmpi slt, %34, %94 : tensor<64x1xi32, #blocked> loc(#loc43) - %96 = tt.broadcast %95 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %97 = ttg.memdesc_subview %19[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %98 = tt.splat %59 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %99 = arith.andi %98, %96 : tensor<64x256xi1, #blocked> loc(#loc11) - %100 = ttg.async_copy_global_to_local %84, %97 mask %99 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %101 = ttg.async_commit_group %100 loc(#loc13) - %102:18 = scf.for %arg9 = %c0_i64 to %16 step %c1_i64 iter_args(%arg10 = %61, %arg11 = %65#4, %arg12 = %cst_3, %arg13 = %65#0, %arg14 = %65#1, %arg15 = %65#2, %arg16 = %65#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %64, %arg20 = %21, %arg21 = %61, %arg22 = %58, %arg23 = %101, %arg24 = %24#0, %arg25 = %65#0, %arg26 = %24#1, %arg27 = %65#1) -> (i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i64, i64, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i64 { - %105 = arith.subi %16, %c2_i64 : i64 loc(#loc11) - %106 = arith.cmpi slt, %arg9, %105 : i64 loc(#loc11) - %107 = arith.addi %arg19, %c1_i32 : i32 loc(#loc11) - %108 = arith.addi %arg10, %c1_i64 : i64 loc(#loc11) - %109 = arith.remsi %108, %14 : i64 loc(#loc11) - %110 = arith.cmpi eq, %109, %c0_i64 : i64 loc(#loc11) - %111 = arith.select %110, %c0_i32, %107 : i32 loc(#loc11) - %112:5 = scf.if %110 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %163 = arith.addi %arg11, %c132_i32 : i32 loc(#loc11) - %164 = arith.divsi %163, %8 : i32 loc(#loc14) - %165 = arith.muli %164, %c8_i32 : i32 loc(#loc15) - %166 = arith.subi %2, %165 : i32 loc(#loc16) - %167 = arith.minsi %166, %c8_i32 : i32 loc(#loc17) - %168 = arith.remsi %163, %167 : i32 loc(#loc18) - %169 = arith.addi %165, %168 : i32 loc(#loc19) - %170 = arith.remsi %163, %8 : i32 loc(#loc20) - %171 = arith.divsi %170, %167 : i32 loc(#loc21) - %172 = arith.muli %169, %c128_i32 : i32 loc(#loc22) - %173 = arith.muli %171, %c256_i32 : i32 loc(#loc23) - %174 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24) - %175 = tt.splat %172 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %176 = arith.addi %175, %174 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc25) - %177 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc26) - %178 = tt.splat %173 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %179 = arith.addi %178, %177 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %180 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %181 = arith.cmpi slt, %176, %180 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc28) - %182 = arith.select %181, %176, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) - %183 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %184 = arith.cmpi slt, %179, %183 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc30) - %185 = arith.select %184, %179, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc31) - scf.yield %172, %173, %182, %185, %163 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) - } else { - scf.yield %arg13, %arg14, %arg15, %arg16, %arg11 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 loc(#loc11) - } loc(#loc11) - %113 = arith.addi %arg18, %c1_i32 : i32 loc(#loc11) - %114 = arith.cmpi slt, %113, %c3_i32 : i32 loc(#loc11) - %115 = arith.select %114, %113, %c0_i32 : i32 loc(#loc11) - %116 = arith.cmpi ne, %arg20, %c0_i64 : i64 loc(#loc65) - %117 = ttg.memdesc_subview %18[%115, %c0_i32, %c0_i32] 
: !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %118 = ttg.async_wait %arg22 {num = 2 : i32} loc(#loc12) - %119 = ttg.memdesc_subview %19[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %120 = ttng.warp_group_dot %117, %119, %arg12, %116 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc47) - %121:3 = ttng.warp_group_dot_wait %120, %117, %119 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) - %122 = arith.addi %arg17, %c1_i32 : i32 loc(#loc11) - %123 = arith.cmpi slt, %122, %c3_i32 : i32 loc(#loc11) - %124 = arith.select %123, %122, %c0_i32 : i32 loc(#loc11) - %125 = arith.muli %111, %c64_i32 : i32 loc(#loc44) - %126 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %127 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %128 = arith.addi %126, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc45) - %129 = arith.addi %127, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc45) - %130 = tt.expand_dims %112#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc32) - %131 = arith.muli %130, %26 : tensor<128x1xi32, #blocked1> loc(#loc33) - %132 = tt.expand_dims %128 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc34) - %133 = tt.broadcast %131 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %134 = tt.broadcast %132 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc35) - %135 = arith.addi %133, %134 : tensor<128x64xi32, #blocked1> loc(#loc35) - %136 = tt.addptr %32, %135 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc36) - %137 = tt.expand_dims %129 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc37) - %138 = tt.expand_dims %112#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc38) - %139 = arith.muli %138, %36 : tensor<1x256xi32, #blocked> loc(#loc39) - %140 = tt.broadcast %137 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %141 = tt.broadcast %139 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc40) - %142 = arith.addi %140, %141 : tensor<64x256xi32, #blocked> loc(#loc40) - %143 = tt.addptr %41, %142 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc41) - %144 = arith.subi %arg5, %125 : i32 loc(#loc46) - %145 = tt.splat %144 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc42) - %146 = arith.cmpi slt, %28, %145 : tensor<1x64xi32, #blocked1> loc(#loc42) - %147 = tt.broadcast %146 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc12) - %148 = ttg.memdesc_subview %18[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %149 = tt.splat %106 : i1 -> tensor<128x64xi1, 
#blocked1> loc(#loc11) - %150 = arith.andi %149, %147 : tensor<128x64xi1, #blocked1> loc(#loc11) - %151 = ttg.async_copy_global_to_local %136, %148 mask %150 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc12) - %152 = ttg.async_commit_group %151 loc(#loc12) - %153 = tt.splat %144 : i32 -> tensor<64x1xi32, #blocked> loc(#loc43) - %154 = arith.cmpi slt, %34, %153 : tensor<64x1xi32, #blocked> loc(#loc43) - %155 = tt.broadcast %154 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc13) - %156 = ttg.memdesc_subview %19[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %157 = tt.splat %106 : i1 -> tensor<64x256xi1, #blocked> loc(#loc11) - %158 = arith.andi %157, %155 : tensor<64x256xi1, #blocked> loc(#loc11) - %159 = ttg.async_copy_global_to_local %143, %156 mask %158 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc13) - %160 = ttg.async_commit_group %159 loc(#loc13) - %161 = arith.subi %14, %c1_i64 : i64 loc(#loc11) - %162 = arith.cmpi eq, %arg20, %161 : i64 loc(#loc11) - scf.if %162 { - %163:3 = ttng.warp_group_dot_wait %121#0, %117, %119 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc47) - %164 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc24) - %165 = tt.splat %arg24 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) - %166 = arith.addi %165, %164 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc25) - %167 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc26) - %168 = tt.splat %arg26 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) - %169 = arith.addi %168, %167 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc27) - %170 = tt.expand_dims %166 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc48) - %171 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc49) - %172 = arith.muli %171, %170 : tensor<128x1xi32, #blocked2> loc(#loc49) - %173 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc50) - %174 = tt.addptr %173, %172 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc50) - %175 = tt.expand_dims %169 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> loc(#loc51) - %176 = tt.broadcast %174 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc52) - %177 = tt.broadcast %175 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc52) - %178 = tt.addptr %176, %177 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc52) - %179 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc53) - %180 = arith.cmpi slt, %170, %179 : tensor<128x1xi32, #blocked2> loc(#loc53) - %181 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc54) - %182 = arith.cmpi slt, %175, %181 : tensor<1x256xi32, #blocked2> loc(#loc54) - %183 = tt.broadcast %180 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) - %184 = tt.broadcast 
%182 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc55) - %185 = arith.andi %183, %184 : tensor<128x256xi1, #blocked2> loc(#loc55) - %186 = arith.truncf %163#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc56) - %187 = ttg.convert_layout %186 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc57) - tt.store %178, %187, %185 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc57) - } loc(#loc11) - scf.yield %109, %112#4, %121#0, %112#0, %112#1, %112#2, %112#3, %124, %115, %111, %arg21, %109, %arg23, %160, %arg25, %112#0, %arg27, %112#1 : i64, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i64, i64, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 loc(#loc11) - } loc(#loc11) - %103 = ttng.warp_group_dot_wait %102#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc11) - %104 = ttg.async_wait {num = 0 : i32} loc(#loc11) - ttg.local_dealloc %18 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc11) - ttg.local_dealloc %19 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc11) - tt.return loc(#loc58) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":281:30) -#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":282:27) -#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":283:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":284:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":285:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":286:38) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":288:35) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":290:47) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:24) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":313:24) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":291:30) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":292:33) -#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":293:39) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":293:52) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":294:41) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":294:31) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":295:27) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":295:48) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":297:26) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":298:26) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":299:41) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":299:28) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":300:41) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":300:28) -#loc28 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":301:37) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":301:49) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":302:37) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":302:49) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:38) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:49) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:68) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:61) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":309:30) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:37) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:68) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:79) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:60) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":310:30) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:64) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":313:64) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":308:26) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":308:41) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":312:68) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":314:39) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:45) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:37) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:25) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:76) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":318:56) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:37) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:62) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":319:43) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":323:31) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":324:25) -#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":290:4) -#loc59 = loc(callsite(#loc3 at #loc4)) -#loc60 = loc(callsite(#loc5 at #loc4)) -#loc61 = loc(callsite(#loc3 at #loc6)) -#loc62 = loc(callsite(#loc5 at #loc6)) -#loc63 = loc(callsite(#loc3 at #loc7)) -#loc64 = loc(callsite(#loc5 at #loc7)) -#loc65 = loc(fused[#loc47, #loc11]) - diff --git a/new2.mlir b/new2.mlir deleted file mode 100644 index 9e82a5ef1817..000000000000 --- a/new2.mlir +++ /dev/null @@ -1,309 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = 
#ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c2_i64 = arith.constant 2 : i32 - %c3_i32 = arith.constant 3 : i32 - %c-1_i32 = arith.constant -1 : i32 - %c1_i64 = arith.constant 1 : i32 - %c0_i64 = arith.constant 0 : i32 - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %range_1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %range_2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %range_3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %range_4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %splat_1 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %splat_2 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %2, %4 : i32 - %8 = arith.muli %4, %c8_i32 : i32 - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %11 = arith.subi %7, %0 : i32 - %12 = arith.ceildivsi %11, %c132_i32 : i32 - %13 = arith.addi %6, %c0_i32 : i32 - %14 = arith.maxsi %13, %c1_i64 : i32 - %15 = arith.addi %12, %c0_i32 : i32 - %16 = arith.muli %15, %14 : i32 - %17 = arith.subi %0, %c132_i32 : i32 - %18 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %19 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %20 = arith.cmpi sgt, %16, %c0_i64 : i32 - %21 = arith.constant 0 : i32 - %22 = arith.cmpi eq, %21, %c0_i64 : i32 - %23 = arith.select %22, %0, %17 : i32 - %24:4 = scf.if %22 -> (i32, i32, 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %105 = arith.divsi %0, %8 : i32 - %106 = arith.muli %105, %c8_i32 : i32 - %107 = arith.subi %2, %106 : i32 - %108 = arith.minsi %107, %c8_i32 : i32 - %109 = arith.remsi %0, %108 : i32 - %110 = arith.addi %106, %109 : i32 - %111 = arith.remsi %0, %8 : i32 - %112 = arith.divsi %111, %108 : i32 - %113 = arith.muli %110, %c128_i32 : i32 - %114 = arith.muli %112, %c256_i32 : i32 - %115 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %116 = tt.splat %113 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %117 = arith.addi %116, %115 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %118 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %119 = tt.splat %114 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %120 = arith.addi %119, %118 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %121 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %122 = arith.cmpi slt, %117, %121 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %123 = arith.select %122, %117, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %124 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %125 = arith.cmpi slt, %120, %124 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %126 = arith.select %125, %120, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %113, %114, %123, %126 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } else { - scf.yield %c0_i32, %c0_i32, %cst_0, %cst : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } - %25 = tt.expand_dims %24#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %26 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %27 = arith.muli %25, %26 : tensor<128x1xi32, #blocked1> - %28 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %29 = tt.broadcast %27 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %30 = tt.broadcast %28 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %31 = arith.addi %29, %30 : tensor<128x64xi32, #blocked1> - %32 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %33 = tt.addptr %32, %31 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %34 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %35 = tt.expand_dims %24#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %36 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %37 = arith.muli %35, %36 : tensor<1x256xi32, #blocked> - %38 = tt.broadcast %34 : 
tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %39 = tt.broadcast %37 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %40 = arith.addi %38, %39 : tensor<64x256xi32, #blocked> - %41 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %42 = tt.addptr %41, %40 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %43 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %44 = arith.cmpi slt, %28, %43 : tensor<1x64xi32, #blocked1> - %45 = tt.broadcast %44 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %46 = ttg.memdesc_subview %18[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %47 = tt.splat %20 : i1 -> tensor<128x64xi1, #blocked1> - %48 = arith.andi %47, %45 : tensor<128x64xi1, #blocked1> - %49 = ttg.async_copy_global_to_local %33, %46 mask %48 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %50 = ttg.async_commit_group %49 - %51 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %52 = arith.cmpi slt, %34, %51 : tensor<64x1xi32, #blocked> - %53 = tt.broadcast %52 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %54 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %55 = tt.splat %20 : i1 -> tensor<64x256xi1, #blocked> - %56 = arith.andi %55, %53 : tensor<64x256xi1, #blocked> - %57 = ttg.async_copy_global_to_local %42, %54 mask %56 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %58 = ttg.async_commit_group %57 - %59 = arith.cmpi sgt, %16, %c1_i64 : i32 - %60 = arith.addi %21, %c1_i64 : i32 - %61 = arith.remsi %60, %14 : i32 - %62 = arith.cmpi eq, %61, %c0_i64 : i32 - %63 = arith.cmpi ne, %61, %c0_i64 : i32 - %64 = arith.extui %63 : i1 to i32 - %65:5 = scf.if %62 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %105 = arith.addi %23, %c132_i32 : i32 - %106 = arith.divsi %105, %8 : i32 - %107 = arith.muli %106, %c8_i32 : i32 - %108 = arith.subi %2, %107 : i32 - %109 = arith.minsi %108, %c8_i32 : i32 - %110 = arith.remsi %105, %109 : i32 - %111 = arith.addi %107, %110 : i32 - %112 = arith.remsi %105, %8 : i32 - %113 = arith.divsi %112, %109 : i32 - %114 = arith.muli %111, %c128_i32 : i32 - %115 = arith.muli %113, %c256_i32 : i32 - %116 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %117 = tt.splat %114 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %118 = arith.addi %117, %116 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %119 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %120 = tt.splat %115 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %121 = arith.addi %120, %119 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %122 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %123 = arith.cmpi slt, %118, %122 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %124 = arith.select %123, %118, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, 
parent = #blocked1}>> - %125 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %126 = arith.cmpi slt, %121, %125 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %127 = arith.select %126, %121, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %114, %115, %124, %127, %105 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %24#0, %24#1, %24#2, %24#3, %23 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %66 = arith.muli %64, %c64_i32 : i32 - %67 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %68 = tt.splat %66 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %69 = arith.addi %67, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %70 = arith.addi %68, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %71 = tt.expand_dims %65#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %72 = arith.muli %71, %26 : tensor<128x1xi32, #blocked1> - %73 = tt.expand_dims %69 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %74 = tt.broadcast %72 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %75 = tt.broadcast %73 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %76 = arith.addi %74, %75 : tensor<128x64xi32, #blocked1> - %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %78 = tt.expand_dims %70 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %79 = tt.expand_dims %65#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %80 = arith.muli %79, %36 : tensor<1x256xi32, #blocked> - %81 = tt.broadcast %78 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %82 = tt.broadcast %80 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %83 = arith.addi %81, %82 : tensor<64x256xi32, #blocked> - %84 = tt.addptr %41, %83 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %85 = arith.subi %arg5, %66 : i32 - %86 = tt.splat %85 : i32 -> tensor<1x64xi32, #blocked1> - %87 = arith.cmpi slt, %28, %86 : tensor<1x64xi32, #blocked1> - %88 = tt.broadcast %87 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %89 = ttg.memdesc_subview %18[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %90 = tt.splat %59 : i1 -> tensor<128x64xi1, #blocked1> - %91 = arith.andi %90, %88 : tensor<128x64xi1, #blocked1> - %92 = ttg.async_copy_global_to_local %77, %89 mask %91 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %93 = ttg.async_commit_group %92 - %94 = tt.splat %85 : i32 -> tensor<64x1xi32, #blocked> - %95 = arith.cmpi slt, %34, %94 : tensor<64x1xi32, #blocked> - %96 = tt.broadcast %95 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %97 = ttg.memdesc_subview %19[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, 
mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %98 = tt.splat %59 : i1 -> tensor<64x256xi1, #blocked> - %99 = arith.andi %98, %96 : tensor<64x256xi1, #blocked> - %100 = ttg.async_copy_global_to_local %84, %97 mask %99 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %101 = ttg.async_commit_group %100 - %102:18 = scf.for %arg9 = %c0_i64 to %16 step %c1_i64 iter_args(%arg10 = %61, %arg11 = %65#4, %arg12 = %cst_3, %arg13 = %65#0, %arg14 = %65#1, %arg15 = %65#2, %arg16 = %65#3, %arg17 = %c1_i32, %arg18 = %c-1_i32, %arg19 = %64, %arg20 = %21, %arg21 = %61, %arg22 = %58, %arg23 = %101, %arg24 = %24#0, %arg25 = %65#0, %arg26 = %24#1, %arg27 = %65#1) -> (i32, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { - %105 = arith.subi %16, %c2_i64 : i32 - %106 = arith.cmpi slt, %arg9, %105 : i32 - %107 = arith.addi %arg19, %c1_i32 : i32 - %108 = arith.addi %107, %c0_i32 : i32 - %109 = arith.remsi %108, %14 : i32 - %110 = arith.cmpi eq, %109, %c0_i64 : i32 - %111 = arith.select %110, %c0_i32, %107 : i32 - %112:5 = scf.if %110 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %163 = arith.addi %arg11, %c132_i32 : i32 - %164 = arith.divsi %163, %8 : i32 - %165 = arith.muli %164, %c8_i32 : i32 - %166 = arith.subi %2, %165 : i32 - %167 = arith.minsi %166, %c8_i32 : i32 - %168 = arith.remsi %163, %167 : i32 - %169 = arith.addi %165, %168 : i32 - %170 = arith.remsi %163, %8 : i32 - %171 = arith.divsi %170, %167 : i32 - %172 = arith.muli %169, %c128_i32 : i32 - %173 = arith.muli %171, %c256_i32 : i32 - %175 = tt.splat %172 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %176 = arith.addi %175, %range_3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %178 = tt.splat %173 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %179 = arith.addi %178, %range_4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %181 = arith.cmpi slt, %176, %splat_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %182 = arith.select %181, %176, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %184 = arith.cmpi slt, %179, %splat_2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %185 = arith.select %184, %179, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %172, %173, %182, %185, %163 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %arg13, %arg14, %arg15, %arg16, %arg11 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %113 = arith.addi %arg18, %c1_i32 : i32 - %114 = arith.cmpi slt, %113, %c3_i32 : i32 - %115 = arith.select %114, %113, %c0_i32 : i32 - %116 = arith.cmpi ne, %arg20, %c0_i64 : i32 - %117 = ttg.memdesc_subview %18[%115, %c0_i32, %c0_i32] : 
!ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %118 = ttg.async_wait %arg22 {num = 2 : i32} - %119 = ttg.memdesc_subview %19[%115, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %120 = ttng.warp_group_dot %117, %119, %arg12, %116 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> - %121:3 = ttng.warp_group_dot_wait %120, %117, %119 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %122 = arith.addi %arg17, %c1_i32 : i32 - %123 = arith.cmpi slt, %122, %c3_i32 : i32 - %124 = arith.select %123, %122, %c0_i32 : i32 - %125 = arith.muli %111, %c64_i32 : i32 - %126 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %127 = tt.splat %125 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %128 = arith.addi %126, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %129 = arith.addi %127, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %130 = tt.expand_dims %112#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %131 = arith.muli %130, %26 : tensor<128x1xi32, #blocked1> - %132 = tt.expand_dims %128 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %133 = tt.broadcast %131 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %134 = tt.broadcast %132 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %135 = arith.addi %133, %134 : tensor<128x64xi32, #blocked1> - %136 = tt.addptr %32, %135 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %137 = tt.expand_dims %129 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %138 = tt.expand_dims %112#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %139 = arith.muli %138, %36 : tensor<1x256xi32, #blocked> - %140 = tt.broadcast %137 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %141 = tt.broadcast %139 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %142 = arith.addi %140, %141 : tensor<64x256xi32, #blocked> - %143 = tt.addptr %41, %142 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %144 = arith.subi %arg5, %125 : i32 - %145 = tt.splat %144 : i32 -> tensor<1x64xi32, #blocked1> - %146 = arith.cmpi slt, %28, %145 : tensor<1x64xi32, #blocked1> - %147 = tt.broadcast %146 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %148 = ttg.memdesc_subview %18[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %149 = tt.splat %106 : i1 -> tensor<128x64xi1, #blocked1> - %150 = arith.andi %149, %147 : tensor<128x64xi1, #blocked1> - %151 = ttg.async_copy_global_to_local %136, %148 mask %150 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %152 = ttg.async_commit_group %151 - %153 = tt.splat %144 : i32 -> tensor<64x1xi32, #blocked> - %154 = arith.cmpi slt, %34, %153 : tensor<64x1xi32, #blocked> - %155 = tt.broadcast %154 : tensor<64x1xi1, #blocked> -> 
tensor<64x256xi1, #blocked> - %156 = ttg.memdesc_subview %19[%124, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %157 = tt.splat %106 : i1 -> tensor<64x256xi1, #blocked> - %158 = arith.andi %157, %155 : tensor<64x256xi1, #blocked> - %159 = ttg.async_copy_global_to_local %143, %156 mask %158 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %160 = ttg.async_commit_group %159 - %161 = arith.subi %14, %c1_i64 : i32 - %162 = arith.cmpi eq, %arg20, %161 : i32 - scf.if %162 { - %163:3 = ttng.warp_group_dot_wait %121#0, %117, %119 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %165 = tt.splat %arg24 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %166 = arith.addi %165, %range_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %168 = tt.splat %arg26 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %169 = arith.addi %168, %range_2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %170 = tt.expand_dims %166 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %171 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %172 = arith.muli %171, %170 : tensor<128x1xi32, #blocked2> - %173 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %174 = tt.addptr %173, %172 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %175 = tt.expand_dims %169 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %176 = tt.broadcast %174 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %177 = tt.broadcast %175 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %178 = tt.addptr %176, %177 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %179 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %180 = arith.cmpi slt, %170, %179 : tensor<128x1xi32, #blocked2> - %181 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %182 = arith.cmpi slt, %175, %181 : tensor<1x256xi32, #blocked2> - %183 = tt.broadcast %180 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %184 = tt.broadcast %182 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %185 = arith.andi %183, %184 : tensor<128x256xi1, #blocked2> - %186 = arith.truncf %163#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %187 = ttg.convert_layout %186 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %178, %187, %185 : tensor<128x256x!tt.ptr, #blocked2> - } - scf.yield %109, %112#4, %121#0, %112#0, %112#1, %112#2, %112#3, %124, %115, %111, %arg21, %109, %arg23, %160, %arg25, %112#0, %arg27, %112#1 : i32, i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 - } - %103 = ttng.warp_group_dot_wait %102#2 {pendings = 0 : i32} : tensor<128x256xf32, #mma> - %104 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %18 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %19 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - tt.return - } -} - diff --git a/new3.mlir b/new3.mlir deleted file mode 
100644 index 3b5823443c13..000000000000 --- a/new3.mlir +++ /dev/null @@ -1,291 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c2_i32 = arith.constant 2 : i32 - %c3_i32 = arith.constant 3 : i32 - %c-1_i32 = arith.constant -1 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %5 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %6 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %17 = arith.subi %13, %0 : i32 - %18 = arith.ceildivsi %17, 
%c132_i32 : i32 - %19 = arith.maxsi %12, %c1_i32 : i32 - %20 = arith.muli %18, %19 : i32 - %21 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %22 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %23 = arith.cmpi sgt, %20, %c0_i32 : i32 - %24 = arith.divsi %0, %14 : i32 - %25 = arith.muli %24, %c8_i32 : i32 - %26 = arith.subi %8, %25 : i32 - %27 = arith.minsi %26, %c8_i32 : i32 - %28 = arith.remsi %0, %27 : i32 - %29 = arith.addi %25, %28 : i32 - %30 = arith.remsi %0, %14 : i32 - %31 = arith.divsi %30, %27 : i32 - %32 = arith.muli %29, %c128_i32 : i32 - %33 = arith.muli %31, %c256_i32 : i32 - %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %35 = arith.addi %34, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %37 = arith.addi %36, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %38 = arith.cmpi slt, %35, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %39 = arith.select %38, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %40 = arith.cmpi slt, %37, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %41 = arith.select %40, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %43 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %44 = arith.muli %42, %43 : tensor<128x1xi32, #blocked1> - %45 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %46 = tt.broadcast %44 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %47 = tt.broadcast %45 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %48 = arith.addi %46, %47 : tensor<128x64xi32, #blocked1> - %49 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %50 = tt.addptr %49, %48 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %51 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %52 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %53 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %54 = arith.muli %52, %53 : tensor<1x256xi32, #blocked> - %55 = tt.broadcast %51 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %56 = tt.broadcast %54 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %57 = arith.addi %55, %56 : tensor<64x256xi32, #blocked> - %58 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %59 = tt.addptr %58, %57 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %60 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %61 = arith.cmpi slt, %45, %60 : tensor<1x64xi32, #blocked1> - %62 = tt.broadcast %61 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %63 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : 
!ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %64 = tt.splat %23 : i1 -> tensor<128x64xi1, #blocked1> - %65 = arith.andi %64, %62 : tensor<128x64xi1, #blocked1> - %66 = ttg.async_copy_global_to_local %50, %63 mask %65 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %67 = ttg.async_commit_group %66 - %68 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %69 = arith.cmpi slt, %51, %68 : tensor<64x1xi32, #blocked> - %70 = tt.broadcast %69 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %71 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %72 = tt.splat %23 : i1 -> tensor<64x256xi1, #blocked> - %73 = arith.andi %72, %70 : tensor<64x256xi1, #blocked> - %74 = ttg.async_copy_global_to_local %59, %71 mask %73 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %75 = ttg.async_commit_group %74 - %76 = arith.cmpi sgt, %20, %c1_i32 : i32 - %77 = arith.remsi %c1_i32, %19 : i32 - %78 = arith.cmpi eq, %77, %c0_i32 : i32 - %79 = arith.cmpi ne, %77, %c0_i32 : i32 - %80 = arith.extui %79 : i1 to i32 - %81:5 = scf.if %78 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %121 = arith.addi %0, %c132_i32 : i32 - %122 = arith.divsi %121, %14 : i32 - %123 = arith.muli %122, %c8_i32 : i32 - %124 = arith.subi %8, %123 : i32 - %125 = arith.minsi %124, %c8_i32 : i32 - %126 = arith.remsi %121, %125 : i32 - %127 = arith.addi %123, %126 : i32 - %128 = arith.remsi %121, %14 : i32 - %129 = arith.divsi %128, %125 : i32 - %130 = arith.muli %127, %c128_i32 : i32 - %131 = arith.muli %129, %c256_i32 : i32 - %132 = tt.splat %130 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %133 = arith.addi %132, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %134 = tt.splat %131 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %135 = arith.addi %134, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %136 = arith.cmpi slt, %133, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %137 = arith.select %136, %133, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %138 = arith.cmpi slt, %135, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %139 = arith.select %138, %135, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %130, %131, %137, %139, %121 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %32, %33, %39, %41, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %82 = arith.muli %80, %c64_i32 : i32 - %83 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %84 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %85 = arith.addi %83, %15 : tensor<64xi32, #ttg.slice<{dim = 0, 
parent = #blocked1}>> - %86 = arith.addi %84, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %87 = tt.expand_dims %81#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %88 = arith.muli %87, %43 : tensor<128x1xi32, #blocked1> - %89 = tt.expand_dims %85 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %90 = tt.broadcast %88 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %91 = tt.broadcast %89 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %92 = arith.addi %90, %91 : tensor<128x64xi32, #blocked1> - %93 = tt.addptr %49, %92 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %94 = tt.expand_dims %86 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %95 = tt.expand_dims %81#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %96 = arith.muli %95, %53 : tensor<1x256xi32, #blocked> - %97 = tt.broadcast %94 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %98 = tt.broadcast %96 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %99 = arith.addi %97, %98 : tensor<64x256xi32, #blocked> - %100 = tt.addptr %58, %99 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %101 = arith.subi %arg5, %82 : i32 - %102 = tt.splat %101 : i32 -> tensor<1x64xi32, #blocked1> - %103 = arith.cmpi slt, %45, %102 : tensor<1x64xi32, #blocked1> - %104 = tt.broadcast %103 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %105 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %106 = tt.splat %76 : i1 -> tensor<128x64xi1, #blocked1> - %107 = arith.andi %106, %104 : tensor<128x64xi1, #blocked1> - %108 = ttg.async_copy_global_to_local %93, %105 mask %107 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %109 = ttg.async_commit_group %108 - %110 = tt.splat %101 : i32 -> tensor<64x1xi32, #blocked> - %111 = arith.cmpi slt, %51, %110 : tensor<64x1xi32, #blocked> - %112 = tt.broadcast %111 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %113 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %114 = tt.splat %76 : i1 -> tensor<64x256xi1, #blocked> - %115 = arith.andi %114, %112 : tensor<64x256xi1, #blocked> - %116 = ttg.async_copy_global_to_local %100, %113 mask %115 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %117 = ttg.async_commit_group %116 - %lol = arith.subi %12, %c1_i32 : i32 - %118:16 = scf.for %arg9 = %c0_i32 to %20 step %c1_i32 iter_args( - %arg10 = %81#4, %arg11 = %cst_3, %arg12 = %81#0, %arg13 = %81#1, - %arg14 = %81#2, %arg15 = %81#3, %arg16 = %c1_i32, %arg17 = %c-1_i32, - %arg18 = %80, %arg19 = %c0_i32, %arg21 = %75, %arg22 = %117, %arg23 = %32, %arg24 = %81#0, %arg25 = %33, %arg26 = %81#1) -> (i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { - %121 = arith.subi %20, %c2_i32 : i32 - %122 = arith.cmpi slt, %arg9, %121 : i32 - %rollover = 
arith.cmpi eq, %arg18, %lol : i32 - %123 = arith.addi %arg18, %c1_i32 : i32 - %126 = arith.select %rollover, %c0_i32, %123 : i32 - %125 = arith.cmpi eq, %126, %c0_i32 : i32 - %127:5 = scf.if %125 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %178 = arith.addi %arg10, %c132_i32 : i32 - %179 = arith.divsi %178, %14 : i32 - %180 = arith.muli %179, %c8_i32 : i32 - %181 = arith.subi %8, %180 : i32 - %182 = arith.minsi %181, %c8_i32 : i32 - %183 = arith.remsi %178, %182 : i32 - %184 = arith.addi %180, %183 : i32 - %185 = arith.remsi %178, %14 : i32 - %186 = arith.divsi %185, %182 : i32 - %187 = arith.muli %184, %c128_i32 : i32 - %188 = arith.muli %186, %c256_i32 : i32 - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %190 = arith.addi %189, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %192 = arith.addi %191, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %193 = arith.cmpi slt, %190, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %195 = arith.cmpi slt, %192, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %187, %188, %194, %196, %178 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %arg12, %arg13, %arg14, %arg15, %arg10 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %128 = arith.addi %arg17, %c1_i32 : i32 - %129 = arith.cmpi slt, %128, %c3_i32 : i32 - %130 = arith.select %129, %128, %c0_i32 : i32 - %131 = arith.cmpi ne, %arg19, %c0_i32 : i32 - %132 = ttg.memdesc_subview %21[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %133 = ttg.async_wait %arg21 {num = 2 : i32} - %134 = ttg.memdesc_subview %22[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %135 = ttng.warp_group_dot %132, %134, %arg11, %131 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %137 = arith.addi %arg16, %c1_i32 : i32 - %138 = arith.cmpi slt, %137, %c3_i32 : i32 - %139 = arith.select %138, %137, %c0_i32 : i32 - %140 = arith.muli %126, %c64_i32 : i32 - %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %143 = 
arith.addi %141, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %144 = arith.addi %142, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %145 = tt.expand_dims %127#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %146 = arith.muli %145, %43 : tensor<128x1xi32, #blocked1> - %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> - %151 = tt.addptr %49, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %153 = tt.expand_dims %127#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %154 = arith.muli %153, %53 : tensor<1x256xi32, #blocked> - %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %157 = arith.addi %155, %156 : tensor<64x256xi32, #blocked> - %158 = tt.addptr %58, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %159 = arith.subi %arg5, %140 : i32 - %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> - %161 = arith.cmpi slt, %45, %160 : tensor<1x64xi32, #blocked1> - %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %163 = ttg.memdesc_subview %21[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %164 = tt.splat %122 : i1 -> tensor<128x64xi1, #blocked1> - %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> - %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %167 = ttg.async_commit_group %166 - %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> - %169 = arith.cmpi slt, %51, %168 : tensor<64x1xi32, #blocked> - %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %171 = ttg.memdesc_subview %22[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %172 = tt.splat %122 : i1 -> tensor<64x256xi1, #blocked> - %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> - %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %175 = ttg.async_commit_group %174 - %176 = arith.subi %19, %c1_i32 : i32 - %177 = arith.cmpi eq, %arg19, %176 : i32 - scf.if %177 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %179 = tt.splat %arg23 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %180 = arith.addi %179, %1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %181 = tt.splat %arg25 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %182 = arith.addi %181, %2 : tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked2}>> - %183 = tt.expand_dims %180 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %184 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %185 = arith.muli %184, %183 : tensor<128x1xi32, #blocked2> - %186 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %187 = tt.addptr %186, %185 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %188 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %189 = tt.broadcast %187 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %190 = tt.broadcast %188 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %191 = tt.addptr %189, %190 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %192 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %193 = arith.cmpi slt, %183, %192 : tensor<128x1xi32, #blocked2> - %194 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %195 = arith.cmpi slt, %188, %194 : tensor<1x256xi32, #blocked2> - %196 = tt.broadcast %193 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %197 = tt.broadcast %195 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %198 = arith.andi %196, %197 : tensor<128x256xi1, #blocked2> - %199 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %200 = ttg.convert_layout %199 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %191, %200, %198 : tensor<128x256x!tt.ptr, #blocked2> - } - scf.yield %127#4, %136#0, %127#0, %127#1, - %127#2, %127#3, %139, %130, - %126, %arg18, %arg22, %175, %arg24, %127#0, %arg26, %127#1 : i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 - } - %119 = ttng.warp_group_dot_wait %118#1 {pendings = 0 : i32} : tensor<128x256xf32, #mma> - %120 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %21 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %22 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - tt.return - } -} - diff --git a/orig.mlir b/orig.mlir deleted file mode 100644 index 0af6b4b63b38..000000000000 --- a/orig.mlir +++ /dev/null @@ -1,379 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0) -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), 
%arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":156:0)) attributes {noinline = false} { - %c2_i32 = arith.constant 2 : i32 loc(#loc1) - %c3_i32 = arith.constant 3 : i32 loc(#loc1) - %false = arith.constant false loc(#loc1) - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc1) - %c256_i32 = arith.constant 256 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %c-1_i32 = arith.constant -1 : i32 loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c132_i32 = arith.constant 132 : i32 loc(#loc1) - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> loc(#loc1) - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %c127_i32 = arith.constant 127 : i32 loc(#loc1) - %c255_i32 = arith.constant 255 : i32 loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc78) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc79) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc80) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc81) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc82) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc83) - %7 = arith.muli %2, %4 : i32 loc(#loc8) - %8 = arith.divsi %7, %c132_i32 : i32 loc(#loc9) - %9 = arith.remsi %7, %c132_i32 : i32 loc(#loc10) - %10 = arith.cmpi slt, %0, %9 : i32 loc(#loc11) - %11 = scf.if %10 -> (i32) { - %122 = arith.addi %8, %c1_i32 : i32 loc(#loc13) - scf.yield %122 : i32 loc(#loc13) - } else { - scf.yield %8 : i32 loc(#loc1) - } loc(#loc12) - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc14) - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc14) - %14 = arith.muli %4, %c8_i32 : i32 loc(#loc15) - %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc16) - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc16) - %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc17) - %18 = tt.make_range {end = 256 : i32, start = 0 : i32} 
: tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc17) - %19 = arith.muli %6, %11 : i32 loc(#loc18) - %20 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc19) - %21 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc20) - %22 = arith.cmpi sgt, %19, %c0_i32 : i32 loc(#loc21) - %23 = arith.subi %6, %c1_i32 : i32 loc(#loc22) - %24 = arith.divsi %0, %14 : i32 loc(#loc23) - %25 = arith.muli %24, %c8_i32 : i32 loc(#loc24) - %26 = arith.subi %2, %25 : i32 loc(#loc25) - %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc26) - %28 = arith.remsi %0, %27 : i32 loc(#loc27) - %29 = arith.addi %25, %28 : i32 loc(#loc28) - %30 = arith.remsi %0, %14 : i32 loc(#loc29) - %31 = arith.divsi %30, %27 : i32 loc(#loc30) - %32 = arith.muli %29, %c128_i32 : i32 loc(#loc31) - %33 = arith.muli %31, %c256_i32 : i32 loc(#loc32) - %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %35 = arith.addi %34, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %37 = arith.addi %36, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %38 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) - %39 = arith.cmpi slt, %35, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) - %40 = arith.select %39, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) - %41 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) - %42 = arith.cmpi slt, %37, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) - %43 = arith.select %42, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) - %44 = tt.expand_dims %40 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) - %45 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc40) - %46 = arith.muli %44, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) - %47 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc41) - %48 = tt.broadcast %46 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %49 = tt.broadcast %47 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %50 = arith.addi %48, %49 : tensor<128x64xi32, #blocked1> loc(#loc42) - %51 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc43) - %52 = tt.addptr %51, %50 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) - %53 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) - %54 = tt.expand_dims %43 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) - %55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> loc(#loc46) - %56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked> 
loc(#loc46) - %57 = tt.broadcast %53 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %59 = arith.addi %57, %58 : tensor<64x256xi32, #blocked> loc(#loc47) - %60 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> loc(#loc48) - %61 = tt.addptr %60, %59 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) - %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %63 = arith.cmpi slt, %47, %62 : tensor<1x64xi32, #blocked1> loc(#loc49) - %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) - %65 = ttg.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %66 = tt.splat %22 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) - %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> loc(#loc21) - %68 = ttg.async_copy_global_to_local %52, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %69 = ttg.async_commit_group %68 loc(#loc19) - %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %71 = arith.cmpi slt, %53, %70 : tensor<64x1xi32, #blocked> loc(#loc50) - %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) - %73 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %74 = tt.splat %22 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) - %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> loc(#loc21) - %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %77 = ttg.async_commit_group %76 loc(#loc20) - %78 = arith.cmpi sgt, %19, %c1_i32 : i32 loc(#loc21) - %79 = arith.cmpi ne, %23, %c0_i32 : i32 loc(#loc84) - %80 = arith.extui %79 : i1 to i32 loc(#loc51) - %81 = arith.cmpi eq, %80, %c0_i32 : i32 loc(#loc53) - %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %122 = arith.addi %0, %c132_i32 : i32 loc(#loc55) - %123 = arith.divsi %122, %14 : i32 loc(#loc23) - %124 = arith.muli %123, %c8_i32 : i32 loc(#loc24) - %125 = arith.subi %2, %124 : i32 loc(#loc25) - %126 = arith.minsi %125, %c8_i32 : i32 loc(#loc26) - %127 = arith.remsi %122, %126 : i32 loc(#loc27) - %128 = arith.addi %124, %127 : i32 loc(#loc28) - %129 = arith.remsi %122, %14 : i32 loc(#loc29) - %130 = arith.divsi %129, %126 : i32 loc(#loc30) - %131 = arith.muli %128, %c128_i32 : i32 loc(#loc31) - %132 = arith.muli %130, %c256_i32 : i32 loc(#loc32) - %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %137 = arith.cmpi slt, %134, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) - %138 = arith.select %137, %134, %cst_0 {tt.contiguity = 
dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) - %139 = arith.cmpi slt, %136, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc37) - %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) - scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) - } else { - scf.yield %0, %29, %31, %40, %43 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - } loc(#loc54) - %83 = arith.muli %80, %c64_i32 : i32 loc(#loc56) - %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) - %89 = arith.muli %88, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc41) - %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> loc(#loc42) - %94 = tt.addptr %51, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) - %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) - %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) - %97 = arith.muli %96, %55 : tensor<1x256xi32, #blocked> loc(#loc46) - %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> loc(#loc47) - %101 = tt.addptr %60, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) - %102 = arith.subi %arg5, %83 : i32 loc(#loc58) - %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %104 = arith.cmpi slt, %47, %103 : tensor<1x64xi32, #blocked1> loc(#loc49) - %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) - %106 = ttg.memdesc_subview %20[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) - %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> loc(#loc21) - %109 = 
ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %110 = ttg.async_commit_group %109 loc(#loc19) - %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %112 = arith.cmpi slt, %53, %111 : tensor<64x1xi32, #blocked> loc(#loc50) - %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) - %114 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) - %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> loc(#loc21) - %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %118 = ttg.async_commit_group %117 loc(#loc20) - %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %29, %arg25 = %82#1, %arg26 = %31, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { - %122 = arith.subi %19, %c2_i32 : i32 loc(#loc21) - %123 = arith.cmpi slt, %arg9, %122 : i32 loc(#loc21) - %124 = arith.cmpi eq, %arg10, %23 : i32 loc(#loc52) - %125 = arith.addi %arg10, %c1_i32 : i32 loc(#loc59) - %126 = arith.select %124, %c0_i32, %125 : i32 loc(#loc51) - %127 = arith.cmpi eq, %126, %c0_i32 : i32 loc(#loc53) - %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %178 = arith.addi %arg11, %c132_i32 : i32 loc(#loc55) - %179 = arith.divsi %178, %14 : i32 loc(#loc23) - %180 = arith.muli %179, %c8_i32 : i32 loc(#loc24) - %181 = arith.subi %2, %180 : i32 loc(#loc25) - %182 = arith.minsi %181, %c8_i32 : i32 loc(#loc26) - %183 = arith.remsi %178, %182 : i32 loc(#loc27) - %184 = arith.addi %180, %183 : i32 loc(#loc28) - %185 = arith.remsi %178, %14 : i32 loc(#loc29) - %186 = arith.divsi %185, %182 : i32 loc(#loc30) - %187 = arith.muli %184, %c128_i32 : i32 loc(#loc31) - %188 = arith.muli %186, %c256_i32 : i32 loc(#loc32) - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc33) - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc34) - %193 = arith.cmpi slt, %190, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc35) - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc36) - %195 = arith.cmpi slt, %192, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = 
#blocked}>> loc(#loc37) - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) - scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc38) - } else { - scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc1) - } loc(#loc54) - %129 = arith.addi %arg19, %c1_i32 : i32 loc(#loc21) - %130 = arith.cmpi slt, %129, %c3_i32 : i32 loc(#loc21) - %131 = arith.select %130, %129, %c0_i32 : i32 loc(#loc21) - %132 = ttg.memdesc_subview %20[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %133 = ttg.async_wait %arg20 {num = 2 : i32} loc(#loc19) - %134 = ttg.memdesc_subview %21[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc60) - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc60) - %137 = arith.cmpi ne, %arg22, %23 : i32 loc(#loc85) - %138 = arith.addi %arg18, %c1_i32 : i32 loc(#loc21) - %139 = arith.cmpi slt, %138, %c3_i32 : i32 loc(#loc21) - %140 = arith.select %139, %138, %c0_i32 : i32 loc(#loc21) - %141 = arith.muli %126, %c64_i32 : i32 loc(#loc56) - %142 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %143 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %144 = arith.addi %142, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc57) - %145 = arith.addi %143, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc57) - %146 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc39) - %147 = arith.muli %146, %45 : tensor<128x1xi32, #blocked1> loc(#loc40) - %148 = tt.expand_dims %144 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc41) - %149 = tt.broadcast %147 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %150 = tt.broadcast %148 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc42) - %151 = arith.addi %149, %150 : tensor<128x64xi32, #blocked1> loc(#loc42) - %152 = tt.addptr %51, %151 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc43) - %153 = tt.expand_dims %145 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc44) - %154 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> loc(#loc45) - %155 
= arith.muli %154, %55 : tensor<1x256xi32, #blocked> loc(#loc46) - %156 = tt.broadcast %153 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %157 = tt.broadcast %155 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> loc(#loc47) - %158 = arith.addi %156, %157 : tensor<64x256xi32, #blocked> loc(#loc47) - %159 = tt.addptr %60, %158 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> loc(#loc48) - %160 = arith.subi %arg5, %141 : i32 loc(#loc58) - %161 = tt.splat %160 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc49) - %162 = arith.cmpi slt, %47, %161 : tensor<1x64xi32, #blocked1> loc(#loc49) - %163 = tt.broadcast %162 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc19) - %164 = ttg.memdesc_subview %20[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %165 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc21) - %166 = arith.andi %165, %163 : tensor<128x64xi1, #blocked1> loc(#loc21) - %167 = ttg.async_copy_global_to_local %152, %164 mask %166 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> loc(#loc19) - %168 = ttg.async_commit_group %167 loc(#loc19) - %169 = tt.splat %160 : i32 -> tensor<64x1xi32, #blocked> loc(#loc50) - %170 = arith.cmpi slt, %53, %169 : tensor<64x1xi32, #blocked> loc(#loc50) - %171 = tt.broadcast %170 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> loc(#loc20) - %172 = ttg.memdesc_subview %21[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %173 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> loc(#loc21) - %174 = arith.andi %173, %171 : tensor<64x256xi1, #blocked> loc(#loc21) - %175 = ttg.async_copy_global_to_local %159, %172 mask %174 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc20) - %176 = ttg.async_commit_group %175 loc(#loc20) - %177 = arith.cmpi eq, %arg22, %23 : i32 loc(#loc61) - scf.if %177 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> loc(#loc60) - %179 = arith.muli %arg24, %c128_i32 : i32 loc(#loc63) - %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc64) - %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc64) - %182 = arith.muli %arg26, %c256_i32 : i32 loc(#loc65) - %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc66) - %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc66) - %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> loc(#loc67) - %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc68) - %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> loc(#loc68) - %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> loc(#loc69) - %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> loc(#loc69) - %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, 
#blocked2> loc(#loc70) - %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> loc(#loc71) - %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> loc(#loc71) - %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> loc(#loc71) - %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> loc(#loc72) - %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> loc(#loc72) - %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> loc(#loc73) - %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> loc(#loc73) - %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc74) - %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> loc(#loc74) - %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> loc(#loc74) - %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc75) - %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> loc(#loc76) - tt.store %193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> loc(#loc76) - } loc(#loc62) - scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %137, %140, %131, %arg21, %176, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 loc(#loc21) - } loc(#loc21) - %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc21) - %121 = ttg.async_wait {num = 0 : i32} loc(#loc21) - ttg.local_dealloc %20 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> loc(#loc21) - ttg.local_dealloc %21 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> loc(#loc21) - tt.return loc(#loc77) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:30) -#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) -#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:32) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:31) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:19) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":174:7) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:24) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:35) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:38) -#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:27) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:27) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:32) -#loc19 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:20) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:20) -#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:22) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:38) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:34) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":196:37) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:43) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:56) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:45) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:35) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:31) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:52) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":201:30) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":202:30) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:32) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:32) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:41) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":205:53) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":206:41) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":206:53) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:34) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:45) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:64) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:57) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":210:26) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:33) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:64) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:75) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:56) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":211:26) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:60) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":214:60) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:44) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:28) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:17) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:11) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:23) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:22) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:37) -#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":213:64) -#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:49) -#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":215:35) -#loc61 = 
loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:17) -#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":217:11) -#loc63 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:30) -#loc64 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":218:45) -#loc65 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:30) -#loc66 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":219:45) -#loc67 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:49) -#loc68 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:41) -#loc69 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:29) -#loc70 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:80) -#loc71 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":220:60) -#loc72 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:41) -#loc73 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:66) -#loc74 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":221:47) -#loc75 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":225:35) -#loc76 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":226:29) -#loc77 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":191:4) -#loc78 = loc(callsite(#loc3 at #loc4)) -#loc79 = loc(callsite(#loc5 at #loc4)) -#loc80 = loc(callsite(#loc3 at #loc6)) -#loc81 = loc(callsite(#loc5 at #loc6)) -#loc82 = loc(callsite(#loc3 at #loc7)) -#loc83 = loc(callsite(#loc5 at #loc7)) -#loc84 = loc(fused[#loc51, #loc52]) -#loc85 = loc(fused[#loc60, #loc61]) - diff --git a/orig2.mlir b/orig2.mlir deleted file mode 100644 index 69b0e81760e4..000000000000 --- a/orig2.mlir +++ /dev/null @@ -1,293 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent_fused(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c2_i32 = arith.constant 2 : i32 - %c3_i32 = arith.constant 3 : i32 - %false = arith.constant false - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c0_i32 = arith.constant 0 : i32 - 
%c8_i32 = arith.constant 8 : i32 - %c-1_i32 = arith.constant -1 : i32 - %c1_i32 = arith.constant 1 : i32 - %c132_i32 = arith.constant 132 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %2, %4 : i32 - %8 = arith.divsi %7, %c132_i32 : i32 - %9 = arith.remsi %7, %c132_i32 : i32 - %10 = arith.cmpi slt, %0, %9 : i32 - %11 = scf.if %10 -> (i32) { - %122 = arith.addi %8, %c1_i32 : i32 - scf.yield %122 : i32 - } else { - scf.yield %8 : i32 - } - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %14 = arith.muli %4, %c8_i32 : i32 - %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %19 = arith.muli %6, %11 : i32 - %20 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %21 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %22 = arith.cmpi sgt, %19, %c0_i32 : i32 - %23 = arith.subi %6, %c1_i32 : i32 - %24 = arith.divsi %0, %14 : i32 - %25 = arith.muli %24, %c8_i32 : i32 - %26 = arith.subi %2, %25 : i32 - %27 = arith.minsi %26, %c8_i32 : i32 - %28 = arith.remsi %0, %27 : i32 - %29 = arith.addi %25, %28 : i32 - %30 = arith.remsi %0, %14 : i32 - %31 = arith.divsi %30, %27 : i32 - %32 = arith.muli %29, %c128_i32 : i32 - %33 = arith.muli %31, %c256_i32 : i32 - %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %35 = arith.addi %34, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %37 = arith.addi %36, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %38 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %39 = arith.cmpi slt, %35, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %40 = arith.select %39, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %41 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %42 = arith.cmpi slt, %37, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %43 = arith.select %42, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : 
tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %44 = tt.expand_dims %40 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %45 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %46 = arith.muli %44, %45 : tensor<128x1xi32, #blocked1> - %47 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %48 = tt.broadcast %46 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %49 = tt.broadcast %47 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %50 = arith.addi %48, %49 : tensor<128x64xi32, #blocked1> - %51 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %52 = tt.addptr %51, %50 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %53 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %54 = tt.expand_dims %43 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked> - %57 = tt.broadcast %53 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %58 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %59 = arith.addi %57, %58 : tensor<64x256xi32, #blocked> - %60 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %61 = tt.addptr %60, %59 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %62 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %63 = arith.cmpi slt, %47, %62 : tensor<1x64xi32, #blocked1> - %64 = tt.broadcast %63 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %65 = ttg.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %66 = tt.splat %22 : i1 -> tensor<128x64xi1, #blocked1> - %67 = arith.andi %66, %64 : tensor<128x64xi1, #blocked1> - %68 = ttg.async_copy_global_to_local %52, %65 mask %67 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %69 = ttg.async_commit_group %68 - %70 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %71 = arith.cmpi slt, %53, %70 : tensor<64x1xi32, #blocked> - %72 = tt.broadcast %71 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %73 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %74 = tt.splat %22 : i1 -> tensor<64x256xi1, #blocked> - %75 = arith.andi %74, %72 : tensor<64x256xi1, #blocked> - %76 = ttg.async_copy_global_to_local %61, %73 mask %75 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %77 = ttg.async_commit_group %76 - %78 = arith.cmpi sgt, %19, %c1_i32 : i32 - %79 = arith.cmpi ne, %23, %c0_i32 : i32 - %80 = arith.extui %79 : i1 to i32 - %81 = arith.cmpi eq, %80, %c0_i32 : i32 - %82:5 = scf.if %81 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %122 = arith.addi %0, %c132_i32 : i32 - %123 = arith.divsi %122, %14 : i32 - %124 = arith.muli %123, %c8_i32 : i32 - %125 = arith.subi %2, %124 : i32 - %126 = arith.minsi %125, %c8_i32 : i32 - 
%127 = arith.remsi %122, %126 : i32 - %128 = arith.addi %124, %127 : i32 - %129 = arith.remsi %122, %14 : i32 - %130 = arith.divsi %129, %126 : i32 - %131 = arith.muli %128, %c128_i32 : i32 - %132 = arith.muli %130, %c256_i32 : i32 - %133 = tt.splat %131 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %134 = arith.addi %133, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %135 = tt.splat %132 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %136 = arith.addi %135, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %137 = arith.cmpi slt, %134, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %138 = arith.select %137, %134, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %139 = arith.cmpi slt, %136, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %140 = arith.select %139, %136, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %122, %128, %130, %138, %140 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } else { - scf.yield %0, %29, %31, %40, %43 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } - %83 = arith.muli %80, %c64_i32 : i32 - %84 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %85 = tt.splat %83 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %86 = arith.addi %84, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %87 = arith.addi %85, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %88 = tt.expand_dims %82#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %89 = arith.muli %88, %45 : tensor<128x1xi32, #blocked1> - %90 = tt.expand_dims %86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %91 = tt.broadcast %89 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %92 = tt.broadcast %90 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %93 = arith.addi %91, %92 : tensor<128x64xi32, #blocked1> - %94 = tt.addptr %51, %93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %95 = tt.expand_dims %87 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %96 = tt.expand_dims %82#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %97 = arith.muli %96, %55 : tensor<1x256xi32, #blocked> - %98 = tt.broadcast %95 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %99 = tt.broadcast %97 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %100 = arith.addi %98, %99 : tensor<64x256xi32, #blocked> - %101 = tt.addptr %60, %100 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %102 = arith.subi %arg5, %83 : i32 - %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1> - %104 = arith.cmpi slt, %47, %103 : tensor<1x64xi32, #blocked1> - %105 = tt.broadcast %104 : 
tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %106 = ttg.memdesc_subview %20[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %107 = tt.splat %78 : i1 -> tensor<128x64xi1, #blocked1> - %108 = arith.andi %107, %105 : tensor<128x64xi1, #blocked1> - %109 = ttg.async_copy_global_to_local %94, %106 mask %108 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %110 = ttg.async_commit_group %109 - %111 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked> - %112 = arith.cmpi slt, %53, %111 : tensor<64x1xi32, #blocked> - %113 = tt.broadcast %112 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %114 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %115 = tt.splat %78 : i1 -> tensor<64x256xi1, #blocked> - %116 = arith.andi %115, %113 : tensor<64x256xi1, #blocked> - %117 = ttg.async_copy_global_to_local %101, %114 mask %116 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %118 = ttg.async_commit_group %117 - %119:18 = scf.for %arg9 = %c0_i32 to %19 step %c1_i32 iter_args(%arg10 = %80, %arg11 = %82#0, %arg12 = %82#1, %arg13 = %82#2, %arg14 = %cst_3, %arg15 = %82#3, %arg16 = %82#4, %arg17 = %false, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %77, %arg21 = %118, %arg22 = %c0_i32, %arg23 = %80, %arg24 = %29, %arg25 = %82#1, %arg26 = %31, %arg27 = %82#2) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32) : i32 { - %122 = arith.subi %19, %c2_i32 : i32 - %123 = arith.cmpi slt, %arg9, %122 : i32 - %124 = arith.cmpi eq, %arg10, %23 : i32 - %125 = arith.addi %arg10, %c1_i32 : i32 - %126 = arith.select %124, %c0_i32, %125 : i32 - %127 = arith.cmpi eq, %126, %c0_i32 : i32 - %128:5 = scf.if %127 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { - %178 = arith.addi %arg11, %c132_i32 : i32 - %179 = arith.divsi %178, %14 : i32 - %180 = arith.muli %179, %c8_i32 : i32 - %181 = arith.subi %2, %180 : i32 - %182 = arith.minsi %181, %c8_i32 : i32 - %183 = arith.remsi %178, %182 : i32 - %184 = arith.addi %180, %183 : i32 - %185 = arith.remsi %178, %14 : i32 - %186 = arith.divsi %185, %182 : i32 - %187 = arith.muli %184, %c128_i32 : i32 - %188 = arith.muli %186, %c256_i32 : i32 - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %190 = arith.addi %189, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %192 = arith.addi %191, %17 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %193 = arith.cmpi slt, %190, %38 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %195 = arith.cmpi slt, %192, %41 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : 
tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %178, %184, %186, %194, %196 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } else { - scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - } - %129 = arith.addi %arg19, %c1_i32 : i32 - %130 = arith.cmpi slt, %129, %c3_i32 : i32 - %131 = arith.select %130, %129, %c0_i32 : i32 - %132 = ttg.memdesc_subview %20[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %133 = ttg.async_wait %arg20 {num = 2 : i32} - %134 = ttg.memdesc_subview %21[%131, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %135 = ttng.warp_group_dot %132, %134, %arg14, %arg17 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %137 = arith.cmpi ne, %arg22, %23 : i32 - %138 = arith.addi %arg18, %c1_i32 : i32 - %139 = arith.cmpi slt, %138, %c3_i32 : i32 - %140 = arith.select %139, %138, %c0_i32 : i32 - %141 = arith.muli %126, %c64_i32 : i32 - %142 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %143 = tt.splat %141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %144 = arith.addi %142, %12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %145 = arith.addi %143, %13 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %146 = tt.expand_dims %128#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %147 = arith.muli %146, %45 : tensor<128x1xi32, #blocked1> - %148 = tt.expand_dims %144 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %149 = tt.broadcast %147 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %150 = tt.broadcast %148 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %151 = arith.addi %149, %150 : tensor<128x64xi32, #blocked1> - %152 = tt.addptr %51, %151 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %153 = tt.expand_dims %145 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %154 = tt.expand_dims %128#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %155 = arith.muli %154, %55 : tensor<1x256xi32, #blocked> - %156 = tt.broadcast %153 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %157 = tt.broadcast %155 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %158 = arith.addi %156, %157 : tensor<64x256xi32, #blocked> - %159 = tt.addptr %60, %158 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %160 = arith.subi %arg5, %141 : i32 - %161 = tt.splat %160 : i32 -> tensor<1x64xi32, #blocked1> - %162 = arith.cmpi slt, 
%47, %161 : tensor<1x64xi32, #blocked1> - %163 = tt.broadcast %162 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %164 = ttg.memdesc_subview %20[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %165 = tt.splat %123 : i1 -> tensor<128x64xi1, #blocked1> - %166 = arith.andi %165, %163 : tensor<128x64xi1, #blocked1> - %167 = ttg.async_copy_global_to_local %152, %164 mask %166 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %168 = ttg.async_commit_group %167 - %169 = tt.splat %160 : i32 -> tensor<64x1xi32, #blocked> - %170 = arith.cmpi slt, %53, %169 : tensor<64x1xi32, #blocked> - %171 = tt.broadcast %170 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %172 = ttg.memdesc_subview %21[%140, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %173 = tt.splat %123 : i1 -> tensor<64x256xi1, #blocked> - %174 = arith.andi %173, %171 : tensor<64x256xi1, #blocked> - %175 = ttg.async_copy_global_to_local %159, %172 mask %174 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %176 = ttg.async_commit_group %175 - %177 = arith.cmpi eq, %arg22, %23 : i32 - scf.if %177 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %179 = arith.muli %arg24, %c128_i32 : i32 - %180 = tt.splat %179 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %181 = arith.addi %180, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %182 = arith.muli %arg26, %c256_i32 : i32 - %183 = tt.splat %182 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %184 = arith.addi %183, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %185 = tt.expand_dims %181 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %186 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %187 = arith.muli %186, %185 : tensor<128x1xi32, #blocked2> - %188 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %189 = tt.addptr %188, %187 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %190 = tt.expand_dims %184 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %191 = tt.broadcast %189 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %192 = tt.broadcast %190 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %193 = tt.addptr %191, %192 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %194 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %195 = arith.cmpi slt, %185, %194 : tensor<128x1xi32, #blocked2> - %196 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %197 = arith.cmpi slt, %190, %196 : tensor<1x256xi32, #blocked2> - %198 = tt.broadcast %195 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %199 = tt.broadcast %197 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %200 = arith.andi %198, %199 : tensor<128x256xi1, #blocked2> - %201 = arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %202 = ttg.convert_layout %201 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store 
%193, %202, %200 : tensor<128x256x!tt.ptr, #blocked2> - } - scf.yield %126, %128#0, %128#1, %128#2, %136#0, %128#3, %128#4, %137, %140, %131, %arg21, %176, %arg23, %126, %arg25, %128#1, %arg27, %128#2 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i1, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, i32, i32 - } - %120 = ttng.warp_group_dot_wait %119#4 {pendings = 0 : i32} : tensor<128x256xf32, #mma> - %121 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %20 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %21 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - tt.return - } -} - diff --git a/test.mlir b/test.mlir deleted file mode 100644 index be0afa4cddd7..000000000000 --- a/test.mlir +++ /dev/null @@ -1,178 +0,0 @@ -#loc = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0) -module { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":155:0)) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32> loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %c255_i32 = arith.constant 255 : i32 loc(#loc1) - %c127_i32 = arith.constant 127 : i32 loc(#loc1) - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16> loc(#loc1) - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16> loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c132_i32 = arith.constant 132 : i32 loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %cst_2 = arith.constant dense<0> : tensor<256xi32> loc(#loc1) - %cst_3 = arith.constant dense<0> : tensor<128xi32> loc(#loc1) - %c256_i32 = arith.constant 256 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc63) - %2 = arith.divsi %1, %c128_i32 : i32 loc(#loc64) - %3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc65) - %4 = arith.divsi %3, %c256_i32 : i32 loc(#loc66) - %5 = arith.addi %arg5, %c63_i32 : i32 loc(#loc67) - %6 = arith.divsi %5, %c64_i32 : i32 loc(#loc68) - %7 = arith.muli %2, %4 : i32 loc(#loc8) - %8 = arith.muli %4, %c8_i32 : i32 loc(#loc9) - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc10) - %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc11) 
- %11 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc12) - %12 = tt.splat %arg3 : i32 -> tensor<128xi32> loc(#loc13) - %13 = tt.splat %arg4 : i32 -> tensor<256xi32> loc(#loc14) - %14 = tt.splat %arg6 : i32 -> tensor<128x1xi32> loc(#loc15) - %15 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> loc(#loc16) - %16 = tt.splat %arg7 : i32 -> tensor<1x256xi32> loc(#loc17) - %17 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> loc(#loc18) - %18 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc19) - %19 = tt.expand_dims %9 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc20) - %20 = tt.splat %arg8 : i32 -> tensor<128x1xi32> loc(#loc21) - %21 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc22) - %22 = tt.splat %arg3 : i32 -> tensor<128x1xi32> loc(#loc23) - %23 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc24) - scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { - %24 = arith.divsi %arg9, %8 : i32 loc(#loc26) - %25 = arith.muli %24, %c8_i32 : i32 loc(#loc27) - %26 = arith.subi %2, %25 : i32 loc(#loc28) - %27 = arith.minsi %26, %c8_i32 : i32 loc(#loc29) - %28 = arith.remsi %arg9, %27 : i32 loc(#loc30) - %29 = arith.addi %25, %28 : i32 loc(#loc31) - %30 = arith.remsi %arg9, %8 : i32 loc(#loc32) - %31 = arith.divsi %30, %27 : i32 loc(#loc33) - %32 = arith.muli %29, %c128_i32 : i32 loc(#loc34) - %33 = arith.muli %31, %c256_i32 : i32 loc(#loc35) - %34 = tt.splat %32 : i32 -> tensor<128xi32> loc(#loc36) - %35 = arith.addi %34, %10 : tensor<128xi32> loc(#loc36) - %36 = tt.splat %33 : i32 -> tensor<256xi32> loc(#loc37) - %37 = arith.addi %36, %11 : tensor<256xi32> loc(#loc37) - %38 = arith.cmpi slt, %35, %12 : tensor<128xi32> loc(#loc13) - %39 = arith.select %38, %35, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1>, tensor<128xi32> loc(#loc38) - %40 = arith.cmpi slt, %37, %13 : tensor<256xi32> loc(#loc14) - %41 = arith.select %40, %37, %cst_2 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1>, tensor<256xi32> loc(#loc39) - %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc40) - %43 = arith.muli %42, %14 : tensor<128x1xi32> loc(#loc15) - %44 = tt.broadcast %43 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc41) - %45 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc42) - %46 = arith.muli %45, %16 : tensor<1x256xi32> loc(#loc17) - %47 = tt.broadcast %46 : tensor<1x256xi32> -> tensor<64x256xi32> loc(#loc43) - %48 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst) -> (tensor<128x256xf32>) : i32 { - %62 = arith.muli %arg10, %c64_i32 : i32 loc(#loc45) - %63 = tt.splat %62 : i32 -> tensor<64xi32> loc(#loc46) - %64 = arith.addi %63, %9 : tensor<64xi32> loc(#loc46) - %65 = tt.expand_dims %64 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc47) - %66 = tt.broadcast %65 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc41) - %67 = arith.addi %44, %66 : tensor<128x64xi32> loc(#loc41) - %68 = tt.addptr %15, %67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc16) - %69 = tt.expand_dims %64 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc48) - %70 = tt.broadcast %69 : tensor<64x1xi32> -> tensor<64x256xi32> loc(#loc43) - %71 = arith.addi %70, %47 : tensor<64x256xi32> loc(#loc43) - %72 = tt.addptr %17, %71 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> loc(#loc18) - %73 = arith.subi 
%arg5, %62 : i32 loc(#loc49) - %74 = tt.splat %73 : i32 -> tensor<1x64xi32> loc(#loc50) - %75 = arith.cmpi slt, %18, %74 : tensor<1x64xi32> loc(#loc50) - %76 = tt.broadcast %75 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc51) - %77 = tt.load %68, %76, %cst_1 : tensor<128x64x!tt.ptr> loc(#loc51) - %78 = tt.splat %73 : i32 -> tensor<64x1xi32> loc(#loc52) - %79 = arith.cmpi slt, %19, %78 : tensor<64x1xi32> loc(#loc52) - %80 = tt.broadcast %79 : tensor<64x1xi1> -> tensor<64x256xi1> loc(#loc53) - %81 = tt.load %72, %80, %cst_0 : tensor<64x256x!tt.ptr> loc(#loc53) - %82 = tt.dot %77, %81, %arg11, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x256xf16> -> tensor<128x256xf32> loc(#loc54) - scf.yield %82 : tensor<128x256xf32> loc(#loc55) - } loc(#loc44) - %49 = tt.expand_dims %35 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc56) - %50 = arith.muli %20, %49 : tensor<128x1xi32> loc(#loc21) - %51 = tt.addptr %21, %50 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc22) - %52 = tt.expand_dims %37 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc57) - %53 = tt.broadcast %51 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> loc(#loc58) - %54 = tt.broadcast %52 : tensor<1x256xi32> -> tensor<128x256xi32> loc(#loc58) - %55 = tt.addptr %53, %54 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> loc(#loc58) - %56 = arith.cmpi slt, %49, %22 : tensor<128x1xi32> loc(#loc23) - %57 = arith.cmpi slt, %52, %23 : tensor<1x256xi32> loc(#loc24) - %58 = tt.broadcast %56 : tensor<128x1xi1> -> tensor<128x256xi1> loc(#loc59) - %59 = tt.broadcast %57 : tensor<1x256xi1> -> tensor<128x256xi1> loc(#loc59) - %60 = arith.andi %58, %59 : tensor<128x256xi1> loc(#loc59) - %61 = arith.truncf %48 : tensor<128x256xf32> to tensor<128x256xf16> loc(#loc60) - tt.store %55, %61, %60 : tensor<128x256x!tt.ptr> loc(#loc61) - } loc(#loc25) - tt.return loc(#loc62) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":166:30) -#loc3 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:22) -#loc4 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":167:27) -#loc5 = loc("/root/.pyenv/versions/3.11.8/lib/python3.11/site-packages/triton/language/standard.py":40:28) -#loc6 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":168:27) -#loc7 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":169:25) -#loc8 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":170:28) -#loc9 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":171:38) -#loc10 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":173:35) -#loc11 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:41) -#loc12 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:41) -#loc13 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:37) -#loc14 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:37) -#loc15 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:49) -#loc16 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:30) -#loc17 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:79) -#loc18 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:30) -#loc19 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:53) -#loc20 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:53) 
-#loc21 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:37) -#loc22 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:25) -#loc23 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:37) -#loc24 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:62) -#loc25 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:47) -#loc26 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":176:30) -#loc27 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":177:33) -#loc28 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:39) -#loc29 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":178:52) -#loc30 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:41) -#loc31 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":179:31) -#loc32 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:27) -#loc33 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":180:48) -#loc34 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":182:26) -#loc35 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":183:26) -#loc36 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":184:28) -#loc37 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":185:28) -#loc38 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":186:49) -#loc39 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":187:49) -#loc40 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:38) -#loc41 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:61) -#loc42 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:68) -#loc43 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:60) -#loc44 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":192:24) -#loc45 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:26) -#loc46 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":193:41) -#loc47 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":194:68) -#loc48 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":195:37) -#loc49 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:68) -#loc50 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:64) -#loc51 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":197:24) -#loc52 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:64) -#loc53 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":198:24) -#loc54 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:39) -#loc55 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":199:12) -#loc56 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:45) -#loc57 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:76) -#loc58 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":203:56) -#loc59 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":204:43) -#loc60 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":208:31) -#loc61 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":209:25) -#loc62 = loc("/root/code/triton/python/tutorials/09-persistent-matmul.py":175:4) -#loc63 = loc(callsite(#loc3 at #loc4)) -#loc64 = loc(callsite(#loc5 
at #loc4)) -#loc65 = loc(callsite(#loc3 at #loc6)) -#loc66 = loc(callsite(#loc5 at #loc6)) -#loc67 = loc(callsite(#loc3 at #loc7)) -#loc68 = loc(callsite(#loc5 at #loc7)) - diff --git a/test2.mlir b/test2.mlir deleted file mode 100644 index 425c15288bee..000000000000 --- a/test2.mlir +++ /dev/null @@ -1,128 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %2, %4 : i32 - %8 = arith.muli %4, %c8_i32 : i32 - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %14 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %15 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %16 = tt.splat %arg4 : i32 -> 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %17 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %18 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %19 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %20 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %21 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %22 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %23 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %24 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %25 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %26 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - scf.for %arg9 = %0 to %7 step %c132_i32 : i32 { - %27 = arith.divsi %arg9, %8 : i32 - %28 = arith.muli %27, %c8_i32 : i32 - %29 = arith.subi %2, %28 : i32 - %30 = arith.minsi %29, %c8_i32 : i32 - %31 = arith.remsi %arg9, %30 : i32 - %32 = arith.addi %28, %31 : i32 - %33 = arith.remsi %arg9, %8 : i32 - %34 = arith.divsi %33, %30 : i32 - %35 = arith.muli %32, %c128_i32 : i32 - %36 = arith.muli %34, %c256_i32 : i32 - %37 = tt.splat %35 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %38 = tt.splat %35 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %39 = arith.addi %37, %11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %40 = arith.addi %38, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %41 = tt.splat %36 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %42 = tt.splat %36 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %43 = arith.addi %41, %13 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %44 = arith.addi %42, %14 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %45 = arith.cmpi slt, %39, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %46 = arith.select %45, %39, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %47 = arith.cmpi slt, %43, %16 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %48 = arith.select %47, %43, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %49 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %50 = arith.muli %49, %17 : tensor<128x1xi32, #blocked1> - %51 = tt.broadcast %50 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %52 = tt.expand_dims %48 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %53 = arith.muli %52, %19 : tensor<1x256xi32, #blocked> - %54 = tt.broadcast %53 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %55 = scf.for %arg10 = %c0_i32 to %6 step %c1_i32 iter_args(%arg11 = %cst_3) -> (tensor<128x256xf32, #mma>) : i32 { - %70 = arith.muli %arg10, %c64_i32 : i32 - %71 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %72 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %73 = 
arith.addi %71, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %74 = arith.addi %72, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %75 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %76 = tt.broadcast %75 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %77 = arith.addi %51, %76 : tensor<128x64xi32, #blocked1> - %78 = tt.addptr %18, %77 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79 = tt.expand_dims %74 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %80 = tt.broadcast %79 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %81 = arith.addi %80, %54 : tensor<64x256xi32, #blocked> - %82 = tt.addptr %20, %81 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %83 = arith.subi %arg5, %70 : i32 - %84 = tt.splat %83 : i32 -> tensor<1x64xi32, #blocked1> - %85 = arith.cmpi slt, %21, %84 : tensor<1x64xi32, #blocked1> - %86 = tt.broadcast %85 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %87 = tt.load %78, %86, %cst_1 : tensor<128x64x!tt.ptr, #blocked1> - %88 = ttg.local_alloc %87 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %89 = tt.splat %83 : i32 -> tensor<64x1xi32, #blocked> - %90 = arith.cmpi slt, %22, %89 : tensor<64x1xi32, #blocked> - %91 = tt.broadcast %90 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %92 = tt.load %82, %91, %cst_2 : tensor<64x256x!tt.ptr, #blocked> - %93 = ttg.local_alloc %92 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> - %94 = ttng.warp_group_dot %88, %93, %arg11 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> - scf.yield %94 : tensor<128x256xf32, #mma> - } - %56 = tt.expand_dims %40 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %57 = arith.muli %23, %56 : tensor<128x1xi32, #blocked2> - %58 = tt.addptr %24, %57 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %59 = tt.expand_dims %44 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %60 = tt.broadcast %58 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %61 = tt.broadcast %59 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %62 = tt.addptr %60, %61 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %63 = arith.cmpi slt, %56, %25 : tensor<128x1xi32, #blocked2> - %64 = arith.cmpi slt, %59, %26 : tensor<1x256xi32, #blocked2> - %65 = tt.broadcast %63 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %66 = tt.broadcast %64 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %67 = arith.andi %65, %66 : tensor<128x256xi1, #blocked2> - %68 = arith.truncf %55 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %69 = ttg.convert_layout %68 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %62, %69, %67 : tensor<128x256x!tt.ptr, #blocked2> - } - tt.return - } -} - diff --git a/test3.mlir b/test3.mlir deleted file mode 100644 index a81b31ca4aa2..000000000000 --- a/test3.mlir +++ /dev/null @@ -1,177 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = 
#ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %2, %4 : i32 - %8 = arith.muli %4, %c8_i32 : i32 - %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %11 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %14 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %15 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %16 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %17 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %18 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %19 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %20 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %21 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, 
#blocked1> - %22 = tt.expand_dims %10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %23 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %24 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %25 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %26 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %27 = arith.subi %7, %0 : i32 - %28 = arith.ceildivsi %27, %c132_i32 : i32 - %29 = arith.subi %6, %c0_i32 : i32 - %30 = arith.ceildivsi %29, %c1_i32 : i32 - %c0_i64 = arith.constant 0 : i64 - %31 = arith.extsi %30 : i32 to i64 - %c1_i64 = arith.constant 1 : i64 - %32 = arith.maxsi %c1_i64, %31 : i64 - %33 = arith.addi %c0_i64, %32 : i64 - %c0_i64_4 = arith.constant 0 : i64 - %34 = arith.subi %33, %c0_i64_4 : i64 - %35 = arith.extsi %28 : i32 to i64 - %36 = arith.muli %35, %34 : i64 - %c-1_i64 = arith.constant -1 : i64 - %37 = arith.subi %0, %c132_i32 : i32 - %38 = ub.poison : i32 - %39 = ub.poison : tensor<128x256xf32, #mma> - %40 = ub.poison : i32 - %41 = ub.poison : i32 - %c0_i64_5 = arith.constant 0 : i64 - %c1_i64_6 = arith.constant 1 : i64 - %42:6 = scf.for %arg9 = %c0_i64_5 to %36 step %c1_i64_6 iter_args(%arg10 = %c-1_i64, %arg11 = %37, %arg12 = %38, %arg13 = %39, %arg14 = %40, %arg15 = %41) -> (i64, i32, i32, tensor<128x256xf32, #mma>, i32, i32) : i64 { - %c1_i64_7 = arith.constant 1 : i64 - %43 = arith.addi %arg10, %c1_i64_7 : i64 - %44 = arith.remsi %43, %34 : i64 - %c0_i64_8 = arith.constant 0 : i64 - %45 = arith.subi %c0_i64, %c0_i64_8 : i64 - %46 = arith.cmpi eq, %44, %45 : i64 - %47:5 = scf.if %46 -> (i32, i32, i32, tensor<128x256xf32, #mma>, i32) { - %56 = arith.addi %arg11, %c132_i32 : i32 - %57 = arith.divsi %56, %8 : i32 - %58 = arith.muli %57, %c8_i32 : i32 - %59 = arith.subi %2, %58 : i32 - %60 = arith.minsi %59, %c8_i32 : i32 - %61 = arith.remsi %56, %60 : i32 - %62 = arith.addi %58, %61 : i32 - %63 = arith.remsi %56, %8 : i32 - %64 = arith.divsi %63, %60 : i32 - %65 = arith.muli %62, %c128_i32 : i32 - %66 = arith.muli %64, %c256_i32 : i32 - scf.yield %c0_i32, %65, %66, %cst_3, %56 : i32, i32, i32, tensor<128x256xf32, #mma>, i32 - } else { - scf.yield %arg12, %arg14, %arg15, %arg13, %arg11 : i32, i32, i32, tensor<128x256xf32, #mma>, i32 - } - %48 = arith.extsi %30 : i32 to i64 - %49 = arith.addi %45, %48 : i64 - %50 = arith.cmpi sge, %44, %45 : i64 - %51 = arith.cmpi slt, %44, %49 : i64 - %52 = arith.andi %50, %51 : i1 - %true = arith.constant true - %53:2 = scf.if %true -> (i32, tensor<128x256xf32, #mma>) { - %56 = tt.splat %47#1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %57 = arith.addi %56, %11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %58 = tt.splat %47#2 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %59 = arith.addi %58, %13 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %60 = arith.cmpi slt, %57, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %61 = arith.select %60, %57, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %62 = arith.cmpi slt, %59, %16 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %63 = arith.select %62, %59, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, 
tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %64 = tt.expand_dims %61 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %65 = arith.muli %64, %17 : tensor<128x1xi32, #blocked1> - %66 = tt.broadcast %65 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %67 = tt.expand_dims %63 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %68 = arith.muli %67, %19 : tensor<1x256xi32, #blocked> - %69 = tt.broadcast %68 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %70 = arith.muli %47#0, %c64_i32 : i32 - %71 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %72 = tt.splat %70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %73 = arith.addi %71, %9 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %74 = arith.addi %72, %10 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %75 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %76 = tt.broadcast %75 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %77 = arith.addi %66, %76 : tensor<128x64xi32, #blocked1> - %78 = tt.addptr %18, %77 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %79 = tt.expand_dims %74 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %80 = tt.broadcast %79 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %81 = arith.addi %80, %69 : tensor<64x256xi32, #blocked> - %82 = tt.addptr %20, %81 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %83 = arith.subi %arg5, %70 : i32 - %84 = tt.splat %83 : i32 -> tensor<1x64xi32, #blocked1> - %85 = arith.cmpi slt, %21, %84 : tensor<1x64xi32, #blocked1> - %86 = tt.broadcast %85 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %87 = tt.load %78, %86, %cst_1 : tensor<128x64x!tt.ptr, #blocked1> - %88 = ttg.local_alloc %87 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %89 = tt.splat %83 : i32 -> tensor<64x1xi32, #blocked> - %90 = arith.cmpi slt, %22, %89 : tensor<64x1xi32, #blocked> - %91 = tt.broadcast %90 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %92 = tt.load %82, %91, %cst_2 : tensor<64x256x!tt.ptr, #blocked> - %93 = ttg.local_alloc %92 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> - %94 = ttng.warp_group_dot %88, %93, %47#3 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> - %95 = arith.addi %47#0, %c1_i32 : i32 - scf.yield %95, %94 : i32, tensor<128x256xf32, #mma> - } else { - scf.yield %47#0, %arg13 : i32, tensor<128x256xf32, #mma> - } - %c1_i64_9 = arith.constant 1 : i64 - %54 = arith.subi %34, %c1_i64_9 : i64 - %55 = arith.cmpi eq, %44, %54 : i64 - scf.if %55 { - %56 = tt.splat %47#1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %57 = arith.addi %56, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %58 = tt.splat %47#2 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %59 = arith.addi %58, %14 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %60 = tt.expand_dims %57 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %61 = 
arith.muli %23, %60 : tensor<128x1xi32, #blocked2> - %62 = tt.addptr %24, %61 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %63 = tt.expand_dims %59 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %64 = tt.broadcast %62 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %65 = tt.broadcast %63 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %66 = tt.addptr %64, %65 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %67 = arith.cmpi slt, %60, %25 : tensor<128x1xi32, #blocked2> - %68 = arith.cmpi slt, %63, %26 : tensor<1x256xi32, #blocked2> - %69 = tt.broadcast %67 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %70 = tt.broadcast %68 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %71 = arith.andi %69, %70 : tensor<128x256xi1, #blocked2> - %72 = arith.truncf %53#1 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %73 = ttg.convert_layout %72 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %66, %73, %71 : tensor<128x256x!tt.ptr, #blocked2> - } else { - } - scf.yield %44, %47#4, %53#0, %53#1, %47#1, %47#2 : i64, i32, i32, tensor<128x256xf32, #mma>, i32, i32 - } - tt.return - } -} - diff --git a/test4.mlir b/test4.mlir deleted file mode 100644 index 01d3e533847d..000000000000 --- a/test4.mlir +++ /dev/null @@ -1,192 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked> - %0 = ub.poison : tensor<64x256xi32, #blocked1> - %1 = ub.poison : tensor<128x64xi32, #blocked2> - %2 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %3 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %4 = ub.poison : tensor<128x256xf32, #mma> - %5 = ub.poison : i32 - %c-1_i64 = arith.constant -1 : i64 - %c1_i64 = arith.constant 1 : i64 - %c0_i64 = arith.constant 0 : i64 - %cst_0 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %cst_1 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_2 
= arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked2> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked1> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %true = arith.constant true - %false = arith.constant false - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %6 = tt.get_program_id x : i32 - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %21 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %22 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %23 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked2> - %24 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> - %25 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked1> - %26 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> - %27 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %28 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %29 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> - %30 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %31 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked> - %32 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked> - %33 = arith.cmpi eq, %12, %c0_i32 : i32 - scf.if %33 { - scf.for %arg9 = %6 to %13 step %c132_i32 : i32 { - %34 = arith.divsi %arg9, %14 : i32 - %35 = arith.muli %34, %c8_i32 : i32 - %36 = arith.subi %8, %35 : i32 - %37 = arith.minsi %36, %c8_i32 : i32 - %38 = arith.remsi %arg9, %37 : i32 - %39 = arith.addi %35, %38 : i32 - %40 = arith.remsi %arg9, %14 : i32 - %41 = arith.divsi %40, %37 : i32 - %42 = arith.muli %39, %c128_i32 : i32 - %43 = arith.muli %41, %c256_i32 : i32 - %44 = tt.splat %42 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %45 = arith.addi %44, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %46 = tt.splat %43 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %47 = arith.addi %46, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %48 = tt.expand_dims %45 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %49 = arith.muli %29, %48 : 
tensor<128x1xi32, #blocked> - %50 = tt.addptr %30, %49 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %51 = tt.expand_dims %47 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %52 = tt.broadcast %50 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> - %53 = tt.broadcast %51 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> - %54 = tt.addptr %52, %53 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> - %55 = arith.cmpi slt, %48, %31 : tensor<128x1xi32, #blocked> - %56 = arith.cmpi slt, %51, %32 : tensor<1x256xi32, #blocked> - %57 = tt.broadcast %55 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> - %58 = tt.broadcast %56 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> - %59 = arith.andi %57, %58 : tensor<128x256xi1, #blocked> - tt.store %54, %cst, %59 : tensor<128x256x!tt.ptr, #blocked> - } - } else { - %34 = arith.subi %13, %6 : i32 - %35 = arith.ceildivsi %34, %c132_i32 : i32 - %36 = arith.extsi %12 : i32 to i64 - %37 = arith.maxsi %36, %c1_i64 : i64 - %38 = arith.extsi %35 : i32 to i64 - %39 = arith.muli %38, %37 : i64 - %40 = arith.subi %6, %c132_i32 : i32 - %41:9 = scf.for %arg9 = %c0_i64 to %39 step %c1_i64 iter_args(%arg10 = %c-1_i64, %arg11 = %40, %arg12 = %5, %arg13 = %4, %arg14 = %3, %arg15 = %2, %arg16 = %1, %arg17 = %0, %arg18 = %true) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i1) : i64 { - %42 = arith.addi %arg10, %c1_i64 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 - %43 = arith.remsi %42, %37 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 - %44 = arith.cmpi eq, %43, %c0_i64 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i64 - %45:7 = scf.if %44 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1) { - %74 = arith.addi %arg11, %c132_i32 : i32 - %75 = arith.divsi %74, %14 : i32 - %76 = arith.muli %75, %c8_i32 : i32 - %77 = arith.subi %8, %76 : i32 - %78 = arith.minsi %77, %c8_i32 : i32 - %79 = arith.remsi %74, %78 : i32 - %80 = arith.addi %76, %79 : i32 - %81 = arith.remsi %74, %14 : i32 - %82 = arith.divsi %81, %78 : i32 - %83 = arith.muli %80, %c128_i32 : i32 - %84 = arith.muli %82, %c256_i32 : i32 - %85 = tt.splat %83 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %86 = tt.splat %83 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %87 = arith.addi %85, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %88 = arith.addi %86, %18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %89 = tt.splat %84 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %90 = tt.splat %84 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %91 = arith.addi %89, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %92 = arith.addi %90, %20 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %93 = arith.cmpi slt, %87, %21 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %94 = arith.select %93, %87, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, 
#ttg.slice<{dim = 1, parent = #blocked2}>> - %95 = arith.cmpi slt, %91, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %96 = arith.select %95, %91, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %97 = tt.expand_dims %94 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %98 = arith.muli %97, %23 : tensor<128x1xi32, #blocked2> - %99 = tt.broadcast %98 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %100 = tt.expand_dims %96 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %101 = arith.muli %100, %25 : tensor<1x256xi32, #blocked1> - %102 = tt.broadcast %101 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - scf.yield %88, %92, %99, %102, %74, %c0_i32, %false : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1 - } else { - scf.yield %arg14, %arg15, %arg16, %arg17, %arg11, %arg12, %arg18 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1 - } {loop.cluster = 1 : i32, loop.stage = 0 : i32} - %46 = arith.muli %45#5, %c64_i32 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 - %47 = tt.splat %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %48 = tt.splat %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %49 = arith.addi %47, %15 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %50 = arith.addi %48, %16 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %51 = tt.expand_dims %49 {axis = 0 : i32, loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %52 = tt.broadcast %51 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %53 = arith.addi %45#2, %52 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64xi32, #blocked2> - %54 = tt.addptr %24, %53 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %55 = tt.expand_dims %50 {axis = 1 : i32, loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %56 = tt.broadcast %55 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - %57 = arith.addi %56, %45#3 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256xi32, #blocked1> - %58 = tt.addptr %26, %57 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> - %59 = arith.subi %arg5, %46 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 - %60 = tt.splat %59 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<1x64xi32, #blocked2> - %61 = arith.cmpi slt, %27, %60 {loop.cluster = 4 : i32, loop.stage = 0 
: i32} : tensor<1x64xi32, #blocked2> - %62 = tt.broadcast %61 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> - %63 = tt.load %54, %62, %cst_2 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr, #blocked2> - %64 = ttg.local_alloc %63 {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared, #smem> - %65 = tt.splat %59 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x1xi32, #blocked1> - %66 = arith.cmpi slt, %28, %65 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked1> - %67 = tt.broadcast %66 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> - %68 = tt.load %58, %67, %cst_3 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<64x256x!tt.ptr, #blocked1> - %69 = ttg.local_alloc %68 {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared1, #smem> - %70 = ttng.warp_group_dot %64, %69, %arg13, %45#6 {inputPrecision = 0 : i32, loop.cluster = 2 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> - %71 = arith.addi %45#5, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : i32 - %72 = arith.subi %37, %c1_i64 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i64 - %73 = arith.cmpi eq, %43, %72 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i64 - scf.if %73 { - %74 = tt.expand_dims %45#0 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %75 = arith.muli %29, %74 : tensor<128x1xi32, #blocked> - %76 = tt.addptr %30, %75 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %77 = tt.expand_dims %45#1 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %78 = tt.broadcast %76 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> - %79 = tt.broadcast %77 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> - %80 = tt.addptr %78, %79 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> - %81 = arith.cmpi slt, %74, %31 : tensor<128x1xi32, #blocked> - %82 = arith.cmpi slt, %77, %32 : tensor<1x256xi32, #blocked> - %83 = tt.broadcast %81 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> - %84 = tt.broadcast %82 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> - %85 = arith.andi %83, %84 : tensor<128x256xi1, #blocked> - %86 = arith.truncf %70 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %87 = ttg.convert_layout %86 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> - tt.store %80, %87, %85 : tensor<128x256x!tt.ptr, #blocked> - } {loop.cluster = 5 : i32, loop.stage = 2 : i32} - scf.yield %43, %45#4, %71, %70, %45#0, %45#1, %45#2, %45#3, %true : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i1 - } - } - tt.return - } -} - diff --git a/test5.mlir b/test5.mlir deleted file mode 100644 index 07d0108d2182..000000000000 --- a/test5.mlir +++ /dev/null @@ -1,345 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], 
threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c2_i64 = arith.constant 2 : i64 - %c3_i32 = arith.constant 3 : i32 - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #blocked> - %0 = ub.poison : tensor<64x256xi32, #blocked1> - %1 = ub.poison : tensor<128x64xi32, #blocked2> - %2 = ub.poison : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %3 = ub.poison : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %4 = ub.poison : tensor<128x256xf32, #mma> - %c1_i64 = arith.constant 1 : i64 - %c0_i64 = arith.constant 0 : i64 - %cst_0 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %cst_1 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked2> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked1> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %5 = tt.get_program_id x : i32 - %6 = arith.addi %arg3, %c127_i32 : i32 - %7 = arith.divsi %6, %c128_i32 : i32 - %8 = arith.addi %arg4, %c255_i32 : i32 - %9 = arith.divsi %8, %c256_i32 : i32 - %10 = arith.addi %arg5, %c63_i32 : i32 - %11 = arith.divsi %10, %c64_i32 : i32 - %12 = arith.muli %7, %9 : i32 - %13 = arith.muli %9, %c8_i32 : i32 - %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %20 = tt.splat %arg3 : i32 -> tensor<128xi32, 
#ttg.slice<{dim = 1, parent = #blocked2}>> - %21 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked2> - %23 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> - %24 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked1> - %25 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> - %26 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %27 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %28 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> - %29 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %30 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked> - %31 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked> - %32 = arith.cmpi eq, %11, %c0_i32 : i32 - scf.if %32 { - scf.for %arg9 = %5 to %12 step %c132_i32 : i32 { - %33 = arith.divsi %arg9, %13 : i32 - %34 = arith.muli %33, %c8_i32 : i32 - %35 = arith.subi %7, %34 : i32 - %36 = arith.minsi %35, %c8_i32 : i32 - %37 = arith.remsi %arg9, %36 : i32 - %38 = arith.addi %34, %37 : i32 - %39 = arith.remsi %arg9, %13 : i32 - %40 = arith.divsi %39, %36 : i32 - %41 = arith.muli %38, %c128_i32 : i32 - %42 = arith.muli %40, %c256_i32 : i32 - %43 = tt.splat %41 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %44 = arith.addi %43, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %45 = tt.splat %42 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %46 = arith.addi %45, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %47 = tt.expand_dims %44 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %48 = arith.muli %28, %47 : tensor<128x1xi32, #blocked> - %49 = tt.addptr %29, %48 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %50 = tt.expand_dims %46 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %51 = tt.broadcast %49 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> - %52 = tt.broadcast %50 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> - %53 = tt.addptr %51, %52 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> - %54 = arith.cmpi slt, %47, %30 : tensor<128x1xi32, #blocked> - %55 = arith.cmpi slt, %50, %31 : tensor<1x256xi32, #blocked> - %56 = tt.broadcast %54 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> - %57 = tt.broadcast %55 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> - %58 = arith.andi %56, %57 : tensor<128x256xi1, #blocked> - tt.store %53, %cst, %58 : tensor<128x256x!tt.ptr, #blocked> - } - } else { - %33 = arith.subi %12, %5 : i32 - %34 = arith.ceildivsi %33, %c132_i32 : i32 - %35 = arith.extsi %11 : i32 to i64 - %36 = arith.maxsi %35, %c1_i64 : i64 - %37 = arith.extsi %34 : i32 to i64 - %38 = arith.muli %37, %36 : i64 - %39 = arith.subi %5, %c132_i32 : i32 - %40 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %41 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %42 = arith.cmpi sgt, %38, %c0_i64 : i64 - %43 = arith.remsi %c0_i64, %36 : i64 - %44 = arith.cmpi eq, %43, %c0_i64 : i64 - %45:5 = scf.if %44 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, 
parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { - %108 = arith.divsi %5, %13 : i32 - %109 = arith.muli %108, %c8_i32 : i32 - %110 = arith.subi %7, %109 : i32 - %111 = arith.minsi %110, %c8_i32 : i32 - %112 = arith.remsi %5, %111 : i32 - %113 = arith.addi %109, %112 : i32 - %114 = arith.remsi %5, %13 : i32 - %115 = arith.divsi %114, %111 : i32 - %116 = arith.muli %113, %c128_i32 : i32 - %117 = arith.muli %115, %c256_i32 : i32 - %118 = tt.splat %116 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %119 = tt.splat %116 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %120 = arith.addi %118, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %121 = arith.addi %119, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %122 = tt.splat %117 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %123 = tt.splat %117 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %124 = arith.addi %122, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %125 = arith.addi %123, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %126 = arith.cmpi slt, %120, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %127 = arith.select %126, %120, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %128 = arith.cmpi slt, %124, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %129 = arith.select %128, %124, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %130 = tt.expand_dims %127 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %131 = arith.muli %130, %22 : tensor<128x1xi32, #blocked2> - %132 = tt.broadcast %131 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %133 = tt.expand_dims %129 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %134 = arith.muli %133, %24 : tensor<1x256xi32, #blocked1> - %135 = tt.broadcast %134 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - scf.yield %121, %125, %132, %135, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } else { - scf.yield %3, %2, %1, %0, %39 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } - %46 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %47 = tt.broadcast %46 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %48 = arith.addi %45#2, %47 : tensor<128x64xi32, #blocked2> - %49 = tt.addptr %23, %48 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %50 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %51 = tt.broadcast %50 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - %52 = arith.addi 
%51, %45#3 : tensor<64x256xi32, #blocked1> - %53 = tt.addptr %25, %52 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> - %54 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked2> - %55 = arith.cmpi slt, %26, %54 : tensor<1x64xi32, #blocked2> - %56 = tt.broadcast %55 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> - %57 = ttg.memdesc_subview %40[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> - %58 = tt.splat %42 : i1 -> tensor<128x64xi1, #blocked2> - %59 = arith.andi %58, %56 : tensor<128x64xi1, #blocked2> - %60 = ttg.async_copy_global_to_local %49, %57 mask %59 other %cst_2 : tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> - %61 = ttg.async_commit_group %60 - %62 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked1> - %63 = arith.cmpi slt, %27, %62 : tensor<64x1xi32, #blocked1> - %64 = tt.broadcast %63 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> - %65 = ttg.memdesc_subview %41[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> - %66 = tt.splat %42 : i1 -> tensor<64x256xi1, #blocked1> - %67 = arith.andi %66, %64 : tensor<64x256xi1, #blocked1> - %68 = ttg.async_copy_global_to_local %53, %65 mask %67 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> - %69 = ttg.async_commit_group %68 - %70 = arith.cmpi sgt, %38, %c1_i64 : i64 - %71 = arith.addi %43, %c1_i64 : i64 - %72 = arith.remsi %71, %36 : i64 - %73 = arith.cmpi eq, %72, %c0_i64 : i64 - %74:5 = scf.if %73 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { - %108 = arith.addi %45#4, %c132_i32 : i32 - %109 = arith.divsi %108, %13 : i32 - %110 = arith.muli %109, %c8_i32 : i32 - %111 = arith.subi %7, %110 : i32 - %112 = arith.minsi %111, %c8_i32 : i32 - %113 = arith.remsi %108, %112 : i32 - %114 = arith.addi %110, %113 : i32 - %115 = arith.remsi %108, %13 : i32 - %116 = arith.divsi %115, %112 : i32 - %117 = arith.muli %114, %c128_i32 : i32 - %118 = arith.muli %116, %c256_i32 : i32 - %119 = tt.splat %117 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %120 = tt.splat %117 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %121 = arith.addi %119, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %122 = arith.addi %120, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %123 = tt.splat %118 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %124 = tt.splat %118 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %125 = arith.addi %123, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %126 = arith.addi %124, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %127 = arith.cmpi slt, %121, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %128 = arith.select %127, %121, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %129 = arith.cmpi slt, %125, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %130 = arith.select %129, %125, %cst_0 {tt.contiguity = 
dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %131 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %132 = arith.muli %131, %22 : tensor<128x1xi32, #blocked2> - %133 = tt.broadcast %132 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %134 = tt.expand_dims %130 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %135 = arith.muli %134, %24 : tensor<1x256xi32, #blocked1> - %136 = tt.broadcast %135 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - scf.yield %122, %126, %133, %136, %108 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } else { - scf.yield %45#0, %45#1, %45#2, %45#3, %45#4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } - %75 = arith.select %73, %c0_i32, %c1_i32 : i32 - %76 = arith.muli %75, %c64_i32 : i32 - %77 = tt.splat %76 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %78 = tt.splat %76 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %79 = arith.addi %77, %14 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %80 = arith.addi %78, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %81 = tt.expand_dims %79 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %82 = tt.broadcast %81 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %83 = arith.addi %74#2, %82 : tensor<128x64xi32, #blocked2> - %84 = tt.addptr %23, %83 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %85 = tt.expand_dims %80 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %86 = tt.broadcast %85 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - %87 = arith.addi %86, %74#3 : tensor<64x256xi32, #blocked1> - %88 = tt.addptr %25, %87 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> - %89 = arith.subi %arg5, %76 : i32 - %90 = tt.splat %89 : i32 -> tensor<1x64xi32, #blocked2> - %91 = arith.cmpi slt, %26, %90 : tensor<1x64xi32, #blocked2> - %92 = tt.broadcast %91 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> - %93 = ttg.memdesc_subview %40[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> - %94 = tt.splat %70 : i1 -> tensor<128x64xi1, #blocked2> - %95 = arith.andi %94, %92 : tensor<128x64xi1, #blocked2> - %96 = ttg.async_copy_global_to_local %84, %93 mask %95 other %cst_2 : tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> - %97 = ttg.async_commit_group %96 - %98 = tt.splat %89 : i32 -> tensor<64x1xi32, #blocked1> - %99 = arith.cmpi slt, %27, %98 : tensor<64x1xi32, #blocked1> - %100 = tt.broadcast %99 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> - %101 = ttg.memdesc_subview %41[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 
3x64x256> - %102 = tt.splat %70 : i1 -> tensor<64x256xi1, #blocked1> - %103 = arith.andi %102, %100 : tensor<64x256xi1, #blocked1> - %104 = ttg.async_copy_global_to_local %88, %101 mask %103 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> - %105 = ttg.async_commit_group %104 - %106:23 = scf.for %arg9 = %c0_i64 to %38 step %c1_i64 iter_args(%arg10 = %72, %arg11 = %74#4, %arg12 = %c1_i32, %arg13 = %4, %arg14 = %74#0, %arg15 = %74#1, %arg16 = %74#2, %arg17 = %74#3, %arg18 = %c1_i32, %arg19 = %c-1_i32, %arg20 = %44, %arg21 = %73, %arg22 = %61, %arg23 = %97, %arg24 = %69, %arg25 = %105, %arg26 = %75, %arg27 = %43, %arg28 = %72, %arg29 = %45#0, %arg30 = %74#0, %arg31 = %45#1, %arg32 = %74#1) -> (i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i64, i64, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) : i64 { - %108 = arith.subi %38, %c2_i64 : i64 - %109 = arith.cmpi slt, %arg9, %108 : i64 - %110 = arith.addi %arg10, %c1_i64 : i64 - %111 = arith.remsi %110, %36 : i64 - %112 = arith.cmpi eq, %111, %c0_i64 : i64 - %113:5 = scf.if %112 -> (tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32) { - %160 = arith.addi %arg11, %c132_i32 : i32 - %161 = arith.divsi %160, %13 : i32 - %162 = arith.muli %161, %c8_i32 : i32 - %163 = arith.subi %7, %162 : i32 - %164 = arith.minsi %163, %c8_i32 : i32 - %165 = arith.remsi %160, %164 : i32 - %166 = arith.addi %162, %165 : i32 - %167 = arith.remsi %160, %13 : i32 - %168 = arith.divsi %167, %164 : i32 - %169 = arith.muli %166, %c128_i32 : i32 - %170 = arith.muli %168, %c256_i32 : i32 - %171 = tt.splat %169 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %172 = tt.splat %169 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %173 = arith.addi %171, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %174 = arith.addi %172, %17 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %175 = tt.splat %170 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %176 = tt.splat %170 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %177 = arith.addi %175, %18 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %178 = arith.addi %176, %19 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %179 = arith.cmpi slt, %173, %20 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %180 = arith.select %179, %173, %cst_1 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %181 = arith.cmpi slt, %177, %21 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %182 = arith.select %181, %177, %cst_0 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked1}>> - %183 = tt.expand_dims %180 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %184 = arith.muli %183, %22 : tensor<128x1xi32, #blocked2> - %185 = tt.broadcast %184 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %186 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> - %187 = arith.muli %186, %24 : tensor<1x256xi32, #blocked1> - %188 = tt.broadcast %187 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - scf.yield %174, %178, %185, %188, %160 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } else { - scf.yield %arg14, %arg15, %arg16, %arg17, %arg11 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32 - } - %114 = arith.addi %arg19, %c1_i32 : i32 - %115 = arith.cmpi slt, %114, %c3_i32 : i32 - %116 = arith.select %115, %114, %c0_i32 : i32 - %117 = arith.select %arg20, %cst_4, %arg13 : tensor<128x256xf32, #mma> - %118 = ttg.memdesc_subview %40[%116, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> - %119 = ttg.async_wait %arg24 {num = 2 : i32} - %120 = ttg.memdesc_subview %41[%116, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> - %121 = ttng.warp_group_dot %118, %120, %117 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> - %122:3 = ttng.warp_group_dot_wait %121, %118, %120 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> - %123 = arith.addi %arg26, %c1_i32 : i32 - %124 = arith.addi %arg18, %c1_i32 : i32 - %125 = arith.cmpi slt, %124, %c3_i32 : i32 - %126 = arith.select %125, %124, %c0_i32 : i32 - %127 = arith.select %112, %c0_i32, %123 : i32 - %128 = arith.muli %127, %c64_i32 : i32 - %129 = tt.splat %128 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %130 = tt.splat %128 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %131 = arith.addi %129, %14 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %132 = arith.addi %130, %15 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %133 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2> - %134 = tt.broadcast %133 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2> - %135 = arith.addi %113#2, %134 : tensor<128x64xi32, #blocked2> - %136 = tt.addptr %23, %135 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> - %137 = tt.expand_dims %132 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %138 = tt.broadcast %137 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> - %139 = arith.addi %138, %113#3 : tensor<64x256xi32, #blocked1> - %140 = tt.addptr %25, %139 : tensor<64x256x!tt.ptr, #blocked1>, 
tensor<64x256xi32, #blocked1> - %141 = arith.subi %arg5, %128 : i32 - %142 = tt.splat %141 : i32 -> tensor<1x64xi32, #blocked2> - %143 = arith.cmpi slt, %26, %142 : tensor<1x64xi32, #blocked2> - %144 = tt.broadcast %143 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2> - %145 = ttg.memdesc_subview %40[%126, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64> - %146 = tt.splat %109 : i1 -> tensor<128x64xi1, #blocked2> - %147 = arith.andi %146, %144 : tensor<128x64xi1, #blocked2> - %148 = ttg.async_copy_global_to_local %136, %145 mask %147 other %cst_2 : tensor<128x64x!tt.ptr, #blocked2> -> <128x64xf16, #shared, #smem, mutable, 3x128x64> - %149 = ttg.async_commit_group %148 - %150 = tt.splat %141 : i32 -> tensor<64x1xi32, #blocked1> - %151 = arith.cmpi slt, %27, %150 : tensor<64x1xi32, #blocked1> - %152 = tt.broadcast %151 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> - %153 = ttg.memdesc_subview %41[%126, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable, 3x64x256> - %154 = tt.splat %109 : i1 -> tensor<64x256xi1, #blocked1> - %155 = arith.andi %154, %152 : tensor<64x256xi1, #blocked1> - %156 = ttg.async_copy_global_to_local %140, %153 mask %155 other %cst_3 : tensor<64x256x!tt.ptr, #blocked1> -> <64x256xf16, #shared1, #smem, mutable, 3x64x256> - %157 = ttg.async_commit_group %156 - %158 = arith.subi %36, %c1_i64 : i64 - %159 = arith.cmpi eq, %arg27, %158 : i64 - scf.if %159 { - %160 = tt.expand_dims %arg29 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %161 = arith.muli %28, %160 : tensor<128x1xi32, #blocked> - %162 = tt.addptr %29, %161 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %163 = tt.expand_dims %arg31 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %164 = tt.broadcast %162 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> - %165 = tt.broadcast %163 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> - %166 = tt.addptr %164, %165 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> - %167 = arith.cmpi slt, %160, %30 : tensor<128x1xi32, #blocked> - %168 = arith.cmpi slt, %163, %31 : tensor<1x256xi32, #blocked> - %169 = tt.broadcast %167 : tensor<128x1xi1, #blocked> -> tensor<128x256xi1, #blocked> - %170 = tt.broadcast %168 : tensor<1x256xi1, #blocked> -> tensor<128x256xi1, #blocked> - %171 = arith.andi %169, %170 : tensor<128x256xi1, #blocked> - %172 = arith.truncf %122#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %173 = ttg.convert_layout %172 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> - tt.store %166, %173, %171 : tensor<128x256x!tt.ptr, #blocked> - } - scf.yield %111, %113#4, %123, %122#0, %113#0, %113#1, %113#2, %113#3, %126, %116, %arg21, %112, %arg23, %149, %arg25, %157, %127, %arg28, %111, %arg30, %113#0, %arg32, %113#1 : i64, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<128x64xi32, #blocked2>, tensor<64x256xi32, #blocked1>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i64, i64, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
-      }
-      %107 = ttg.async_wait {num = 0 : i32}
-      ttg.local_dealloc %40 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable>
-      ttg.local_dealloc %41 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable>
-    }
-    tt.return
-  }
-}
-

From ef1133263576c11e11bbf4f55840bf10d04c8c1d Mon Sep 17 00:00:00 2001
From: Mogball
Date: Tue, 28 Jan 2025 01:54:32 -0500
Subject: [PATCH 08/32] check persistent matmul perf

---
 python/tutorials/09-persistent-matmul.py | 532 +----------------------
 1 file changed, 24 insertions(+), 508 deletions(-)

diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 32a7848a3182..260f3c2e3f65 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -152,136 +152,6 @@ def matmul(a, b):
     return c
 
 
-@triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_persistent_fused(a_ptr, b_ptr, c_ptr, #
-                                   M, N, K, #
-                                   stride_am, stride_ak, #
-                                   stride_bk, stride_bn, #
-                                   stride_cm, stride_cn, #
-                                   BLOCK_SIZE_M: tl.constexpr, #
-                                   BLOCK_SIZE_N: tl.constexpr, #
-                                   BLOCK_SIZE_K: tl.constexpr, #
-                                   GROUP_SIZE_M: tl.constexpr, #
-                                   NUM_SMS: tl.constexpr, #
-                                   ):
-    start_pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
-    num_tiles = num_pid_m * num_pid_n
-
-    tiles_per_SM = num_tiles // NUM_SMS
-    if start_pid < num_tiles % NUM_SMS:
-        tiles_per_SM += 1
-
-    tile_id = start_pid - NUM_SMS
-    ki = -1
-
-    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
-
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-
-    pid_m = 0
-    pid_n = 0
-    offs_am = tl.arange(0, BLOCK_SIZE_M)
-    offs_bn = tl.arange(0, BLOCK_SIZE_N)
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    for _ in range(0, k_tiles * tiles_per_SM):
-        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
-        if ki == 0:
-            tile_id += NUM_SMS
-            group_id = tile_id // num_pid_in_group
-            first_pid_m = group_id * GROUP_SIZE_M
-            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-            pid_m = first_pid_m + (tile_id % group_size_m)
-            pid_n = (tile_id % num_pid_in_group) // group_size_m
-
-            start_m = pid_m * BLOCK_SIZE_M
-            start_n = pid_n * BLOCK_SIZE_N
-            offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
-            offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
-            offs_am = tl.where(offs_am < M, offs_am, 0)
-            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
-            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
-            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
-        offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
-        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-
-        a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
-        accumulator = tl.dot(a, b, accumulator)
-
-        if ki == k_tiles - 1:
-            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-            c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-            if (c_ptr.dtype.element_ty == tl.float8e4nv):
-                c = accumulator.to(tl.float8e4nv)
-            else:
-                c = accumulator.to(tl.float16)
-            tl.store(c_ptrs, c,
mask=c_mask) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - -def matmul_persistent_fused(a, b): - configs = { - torch.float8_e4m3fn: { - "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, - "num_warps": 8 - }, torch.float16: { - "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, - "num_warps": 8 - } - } - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.dtype == b.dtype, "Incompatible dtypes" - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - M, K = a.shape - K, N = b.shape - dtype = a.dtype - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=dtype) - # 1D launch kernel where each block gets its own program. - grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) - - matmul_kernel_persistent_fused[grid]( - a, b, c, # - M, N, K, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c.stride(0), c.stride(1), # - BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # - BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # - BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # - GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # - NUM_SMS=NUM_SMS, # - num_stages=configs[dtype]["num_stages"], # - num_warps=configs[dtype]["num_warps"], # - ) - #kernel = matmul_kernel_persistent_fused.warmup( - # a, b, c, # - # M, N, K, # - # a.stride(0), a.stride(1), # - # b.stride(0), b.stride(1), # - # c.stride(0), c.stride(1), # - # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # - # BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # - # BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # - # GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # - # NUM_SMS=NUM_SMS, # - # num_stages=configs[dtype]["num_stages"], # - # num_warps=configs[dtype]["num_warps"], # - # grid=grid - #) - #print(kernel.asm["ttgir"]) - return c - - @triton.jit(launch_metadata=_matmul_launch_metadata) def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # M, N, K, # @@ -340,306 +210,6 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # tl.store(c_ptrs, c, mask=c_mask) -matmul_kernel_persistent_ttgir = """ -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> -#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c2_i32 = arith.constant 2 : i32 - 
%c3_i32 = arith.constant 3 : i32 - %c-1_i32 = arith.constant -1 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %cst_0 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> - %c64_i32 = arith.constant 64 : i32 - %c132_i32 = arith.constant 132 : i32 - %c127_i32 = arith.constant 127 : i32 - %c255_i32 = arith.constant 255 : i32 - %c63_i32 = arith.constant 63 : i32 - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %5 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %6 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %7 = arith.addi %arg3, %c127_i32 : i32 - %8 = arith.divsi %7, %c128_i32 : i32 - %9 = arith.addi %arg4, %c255_i32 : i32 - %10 = arith.divsi %9, %c256_i32 : i32 - %11 = arith.addi %arg5, %c63_i32 : i32 - %12 = arith.divsi %11, %c64_i32 : i32 - %13 = arith.muli %8, %10 : i32 - %14 = arith.muli %10, %c8_i32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %17 = arith.subi %13, %0 : i32 - %18 = arith.ceildivsi %17, %c132_i32 : i32 - %19 = arith.maxsi %12, %c1_i32 : i32 - %20 = arith.muli %18, %19 : i32 - %21 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - %22 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - %23 = arith.cmpi sgt, %20, %c0_i32 : i32 - %24 = arith.divsi %0, %14 : i32 - %25 = arith.muli %24, %c8_i32 : i32 - %26 = arith.subi %8, %25 : i32 - %27 = arith.minsi %26, %c8_i32 : i32 - %28 = arith.remsi %0, %27 : i32 - %29 = arith.addi %25, %28 : i32 - %30 = arith.remsi %0, %14 : i32 - %31 = arith.divsi %30, %27 : i32 - %32 = arith.muli %29, %c128_i32 : i32 - %33 = arith.muli %31, %c256_i32 : i32 - %34 = tt.splat %32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %35 = arith.addi %34, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %36 = tt.splat %33 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %37 = arith.addi %36, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %38 = arith.cmpi slt, %35, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %39 = arith.select %38, %35, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %40 = arith.cmpi slt, %37, %6 : tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> - %41 = arith.select %40, %37, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %42 = tt.expand_dims %39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %43 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> - %44 = arith.muli %42, %43 : tensor<128x1xi32, #blocked1> - %45 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %46 = tt.broadcast %44 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %47 = tt.broadcast %45 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %48 = arith.addi %46, %47 : tensor<128x64xi32, #blocked1> - %49 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> - %50 = tt.addptr %49, %48 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %51 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %52 = tt.expand_dims %41 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %53 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> - %54 = arith.muli %52, %53 : tensor<1x256xi32, #blocked> - %55 = tt.broadcast %51 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %56 = tt.broadcast %54 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %57 = arith.addi %55, %56 : tensor<64x256xi32, #blocked> - %58 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> - %59 = tt.addptr %58, %57 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %60 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %61 = arith.cmpi slt, %45, %60 : tensor<1x64xi32, #blocked1> - %62 = tt.broadcast %61 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %63 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %64 = tt.splat %23 : i1 -> tensor<128x64xi1, #blocked1> - %65 = arith.andi %64, %62 : tensor<128x64xi1, #blocked1> - %66 = ttg.async_copy_global_to_local %50, %63 mask %65 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %67 = ttg.async_commit_group %66 - %68 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> - %69 = arith.cmpi slt, %51, %68 : tensor<64x1xi32, #blocked> - %70 = tt.broadcast %69 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %71 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %72 = tt.splat %23 : i1 -> tensor<64x256xi1, #blocked> - %73 = arith.andi %72, %70 : tensor<64x256xi1, #blocked> - %74 = ttg.async_copy_global_to_local %59, %71 mask %73 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %75 = ttg.async_commit_group %74 - %76 = arith.cmpi sgt, %20, %c1_i32 : i32 - %77 = arith.remsi %c1_i32, %19 : i32 - %78 = arith.cmpi eq, %77, %c0_i32 : i32 - %79 = arith.cmpi ne, %77, %c0_i32 : i32 - %80 = arith.extui %79 : i1 to i32 - %81:5 = scf.if %78 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = 
#blocked}>>, i32) { - %121 = arith.addi %0, %c132_i32 : i32 - %122 = arith.divsi %121, %14 : i32 - %123 = arith.muli %122, %c8_i32 : i32 - %124 = arith.subi %8, %123 : i32 - %125 = arith.minsi %124, %c8_i32 : i32 - %126 = arith.remsi %121, %125 : i32 - %127 = arith.addi %123, %126 : i32 - %128 = arith.remsi %121, %14 : i32 - %129 = arith.divsi %128, %125 : i32 - %130 = arith.muli %127, %c128_i32 : i32 - %131 = arith.muli %129, %c256_i32 : i32 - %132 = tt.splat %130 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %133 = arith.addi %132, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %134 = tt.splat %131 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %135 = arith.addi %134, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %136 = arith.cmpi slt, %133, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %137 = arith.select %136, %133, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %138 = arith.cmpi slt, %135, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %139 = arith.select %138, %135, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %130, %131, %137, %139, %121 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %32, %33, %39, %41, %0 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %82 = arith.muli %80, %c64_i32 : i32 - %83 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %84 = tt.splat %82 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %85 = arith.addi %83, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %86 = arith.addi %84, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %87 = tt.expand_dims %81#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %88 = arith.muli %87, %43 : tensor<128x1xi32, #blocked1> - %89 = tt.expand_dims %85 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %90 = tt.broadcast %88 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %91 = tt.broadcast %89 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %92 = arith.addi %90, %91 : tensor<128x64xi32, #blocked1> - %93 = tt.addptr %49, %92 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %94 = tt.expand_dims %86 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %95 = tt.expand_dims %81#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %96 = arith.muli %95, %53 : tensor<1x256xi32, #blocked> - %97 = tt.broadcast %94 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %98 = tt.broadcast %96 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %99 = arith.addi %97, %98 : tensor<64x256xi32, #blocked> - %100 = tt.addptr %58, %99 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, 
#blocked> - %101 = arith.subi %arg5, %82 : i32 - %102 = tt.splat %101 : i32 -> tensor<1x64xi32, #blocked1> - %103 = arith.cmpi slt, %45, %102 : tensor<1x64xi32, #blocked1> - %104 = tt.broadcast %103 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %105 = ttg.memdesc_subview %21[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %106 = tt.splat %76 : i1 -> tensor<128x64xi1, #blocked1> - %107 = arith.andi %106, %104 : tensor<128x64xi1, #blocked1> - %108 = ttg.async_copy_global_to_local %93, %105 mask %107 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %109 = ttg.async_commit_group %108 - %110 = tt.splat %101 : i32 -> tensor<64x1xi32, #blocked> - %111 = arith.cmpi slt, %51, %110 : tensor<64x1xi32, #blocked> - %112 = tt.broadcast %111 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %113 = ttg.memdesc_subview %22[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %114 = tt.splat %76 : i1 -> tensor<64x256xi1, #blocked> - %115 = arith.andi %114, %112 : tensor<64x256xi1, #blocked> - %116 = ttg.async_copy_global_to_local %100, %113 mask %115 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %117 = ttg.async_commit_group %116 - %lol = arith.subi %12, %c1_i32 : i32 - %118:16 = scf.for %arg9 = %c0_i32 to %20 step %c1_i32 iter_args( - %arg10 = %81#4, %arg11 = %cst_3, %arg12 = %81#0, %arg13 = %81#1, - %arg14 = %81#2, %arg15 = %81#3, %arg16 = %c1_i32, %arg17 = %c-1_i32, - %arg18 = %80, %arg19 = %c0_i32, %arg21 = %75, %arg22 = %117, %arg23 = %32, %arg24 = %81#0, %arg25 = %33, %arg26 = %81#1) -> (i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32) : i32 { - %121 = arith.subi %20, %c2_i32 : i32 - %122 = arith.cmpi slt, %arg9, %121 : i32 - %rollover = arith.cmpi eq, %arg18, %lol : i32 - %123 = arith.addi %arg18, %c1_i32 : i32 - %126 = arith.select %rollover, %c0_i32, %123 : i32 - %125 = arith.cmpi eq, %126, %c0_i32 : i32 - %127:5 = scf.if %125 -> (i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32) { - %178 = arith.addi %arg10, %c132_i32 : i32 - %179 = arith.divsi %178, %14 : i32 - %180 = arith.muli %179, %c8_i32 : i32 - %181 = arith.subi %8, %180 : i32 - %182 = arith.minsi %181, %c8_i32 : i32 - %183 = arith.remsi %178, %182 : i32 - %184 = arith.addi %180, %183 : i32 - %185 = arith.remsi %178, %14 : i32 - %186 = arith.divsi %185, %182 : i32 - %187 = arith.muli %184, %c128_i32 : i32 - %188 = arith.muli %186, %c256_i32 : i32 - %189 = tt.splat %187 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %190 = arith.addi %189, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %191 = tt.splat %188 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %192 = arith.addi %191, %4 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %193 = arith.cmpi slt, %190, %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %194 = arith.select %193, %190, %cst_0 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, 
tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - %195 = arith.cmpi slt, %192, %6 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %196 = arith.select %195, %192, %cst {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - scf.yield %187, %188, %194, %196, %178 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } else { - scf.yield %arg12, %arg13, %arg14, %arg15, %arg10 : i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32 - } - %128 = arith.addi %arg17, %c1_i32 : i32 - %129 = arith.cmpi slt, %128, %c3_i32 : i32 - %130 = arith.select %129, %128, %c0_i32 : i32 - %131 = arith.cmpi ne, %arg19, %c0_i32 : i32 - %132 = ttg.memdesc_subview %21[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %133 = ttg.async_wait %arg21 {num = 2 : i32} - %134 = ttg.memdesc_subview %22[%130, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %135 = ttng.warp_group_dot %132, %134, %arg11, %131 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> * !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf32, #mma> - %136:3 = ttng.warp_group_dot_wait %135, %132, %134 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %137 = arith.addi %arg16, %c1_i32 : i32 - %138 = arith.cmpi slt, %137, %c3_i32 : i32 - %139 = arith.select %138, %137, %c0_i32 : i32 - %140 = arith.muli %126, %c64_i32 : i32 - %141 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %142 = tt.splat %140 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %143 = arith.addi %141, %15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %144 = arith.addi %142, %16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %145 = tt.expand_dims %127#2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> - %146 = arith.muli %145, %43 : tensor<128x1xi32, #blocked1> - %147 = tt.expand_dims %143 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %148 = tt.broadcast %146 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %149 = tt.broadcast %147 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %150 = arith.addi %148, %149 : tensor<128x64xi32, #blocked1> - %151 = tt.addptr %49, %150 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %152 = tt.expand_dims %144 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %153 = tt.expand_dims %127#3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %154 = arith.muli %153, %53 : tensor<1x256xi32, #blocked> - %155 = tt.broadcast %152 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> - %156 = tt.broadcast %154 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> - %157 = arith.addi %155, %156 : tensor<64x256xi32, 
#blocked> - %158 = tt.addptr %58, %157 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> - %159 = arith.subi %arg5, %140 : i32 - %160 = tt.splat %159 : i32 -> tensor<1x64xi32, #blocked1> - %161 = arith.cmpi slt, %45, %160 : tensor<1x64xi32, #blocked1> - %162 = tt.broadcast %161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> - %163 = ttg.memdesc_subview %21[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> - %164 = tt.splat %122 : i1 -> tensor<128x64xi1, #blocked1> - %165 = arith.andi %164, %162 : tensor<128x64xi1, #blocked1> - %166 = ttg.async_copy_global_to_local %151, %163 mask %165 other %cst_1 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xf16, #shared, #smem, mutable> - %167 = ttg.async_commit_group %166 - %168 = tt.splat %159 : i32 -> tensor<64x1xi32, #blocked> - %169 = arith.cmpi slt, %51, %168 : tensor<64x1xi32, #blocked> - %170 = tt.broadcast %169 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> - %171 = ttg.memdesc_subview %22[%139, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %172 = tt.splat %122 : i1 -> tensor<64x256xi1, #blocked> - %173 = arith.andi %172, %170 : tensor<64x256xi1, #blocked> - %174 = ttg.async_copy_global_to_local %158, %171 mask %173 other %cst_2 : tensor<64x256x!tt.ptr, #blocked> -> <64x256xf16, #shared1, #smem, mutable> - %175 = ttg.async_commit_group %174 - %176 = arith.subi %19, %c1_i32 : i32 - %177 = arith.cmpi eq, %arg19, %176 : i32 - scf.if %177 { - %178:3 = ttng.warp_group_dot_wait %136#0, %132, %134 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> - %179 = tt.splat %arg23 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %180 = arith.addi %179, %1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> - %181 = tt.splat %arg25 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %182 = arith.addi %181, %2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> - %183 = tt.expand_dims %180 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> - %184 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> - %185 = arith.muli %184, %183 : tensor<128x1xi32, #blocked2> - %186 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked2> - %187 = tt.addptr %186, %185 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> - %188 = tt.expand_dims %182 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %189 = tt.broadcast %187 : tensor<128x1x!tt.ptr, #blocked2> -> tensor<128x256x!tt.ptr, #blocked2> - %190 = tt.broadcast %188 : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> - %191 = tt.addptr %189, %190 : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> - %192 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked2> - %193 = arith.cmpi slt, %183, %192 : tensor<128x1xi32, #blocked2> - %194 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked2> - %195 = arith.cmpi slt, %188, %194 : tensor<1x256xi32, #blocked2> - %196 = tt.broadcast %193 : tensor<128x1xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %197 = tt.broadcast %195 : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> - %198 = arith.andi %196, %197 : tensor<128x256xi1, #blocked2> - %199 = 
arith.truncf %178#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %200 = ttg.convert_layout %199 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked2> - tt.store %191, %200, %198 : tensor<128x256x!tt.ptr, #blocked2> - } - scf.yield %127#4, %136#0, %127#0, %127#1, - %127#2, %127#3, %139, %130, - %126, %arg18, %arg22, %175, %arg24, %127#0, %arg26, %127#1 : i32, tensor<128x256xf32, #mma>, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32 - } - %119 = ttng.warp_group_dot_wait %118#1 {pendings = 0 : i32} : tensor<128x256xf32, #mma> - %120 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %21 : !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> - ttg.local_dealloc %22 : !ttg.memdesc<3x64x256xf16, #shared1, #smem, mutable> - tt.return - } -} - - -""" - -file = pathlib.Path("matmul_kernel_persistent.ttgir") -file.write_text(matmul_kernel_persistent_ttgir) -matmul_kernel_persistent_precompiled = triton.compile(str(file)) - - def matmul_persistent(a, b): configs = { torch.float8_e4m3fn: { @@ -662,23 +232,6 @@ def matmul_persistent(a, b): # 1D launch kernel where each block gets its own program. grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) - #assert a.stride(1) == 1 and b.stride(0) == 1 and c.stride(1) == 1 - #bytes_per_elem = a.element_size() - #flops_str = f"flops{bytes_per_elem * 8}" - #with proton.scope(f"precompiled [M={M}, N={N}, K={K}]", - # {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): - # matmul_kernel_persistent_precompiled[(grid(configs[torch.float16])[0], 1, 1)]( - # a, - # b, - # c, # - # M, - # N, - # K, # - # a.stride(0), - # b.stride(1), # - # c.stride(0), - # ) - matmul_kernel_persistent[grid]( a, b, c, # M, N, K, # @@ -693,23 +246,6 @@ def matmul_persistent(a, b): num_stages=configs[dtype]["num_stages"], # num_warps=configs[dtype]["num_warps"], # ) - - #kernel = matmul_kernel_persistent.warmup( - # a, b, c, # - # M, N, K, # - # a.stride(0), a.stride(1), # - # b.stride(0), b.stride(1), # - # c.stride(0), c.stride(1), # - # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # - # BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # - # BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # - # GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # - # NUM_SMS=NUM_SMS, # - # num_stages=configs[dtype]["num_stages"], # - # num_warps=configs[dtype]["num_warps"], # - # grid=grid - #) - #print(kernel.asm["ttgir"]) return c @@ -818,6 +354,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n + num_pid_in_group = GROUP_SIZE_M * num_pid_n a_desc = tl._experimental_make_tensor_descriptor( a_ptr, @@ -838,48 +375,27 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], ) - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * 
GROUP_SIZE_M
-            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-            pid_m = first_pid_m + (tile_id % group_size_m)
-            pid_n = (tile_id % num_pid_in_group) // group_size_m
-
-            offs_am = pid_m * BLOCK_SIZE_M
-            offs_bn = pid_n * BLOCK_SIZE_N
+    for tile_id in range(start_pid, num_tiles, NUM_SMS):
+        group_id = tile_id // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (tile_id % group_size_m)
+        pid_n = (tile_id % num_pid_in_group) // group_size_m
-        offs_k = ki * BLOCK_SIZE_K
+        offs_am = pid_m * BLOCK_SIZE_M
+        offs_bn = pid_n * BLOCK_SIZE_N
-        a = a_desc.load([offs_am, offs_k])
-        b = b_desc.load([offs_bn, offs_k])
-        accumulator = tl.dot(a, b.T, accumulator)
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for ki in range(k_tiles):
+            offs_k = ki * BLOCK_SIZE_K
-        if ki == k_tiles - 1:
-            c = accumulator.to(dtype)
+            a = a_desc.load([offs_am, offs_k])
+            b = b_desc.load([offs_bn, offs_k])
+            accumulator = tl.dot(a, b.T, accumulator)
-            c_desc.store([offs_am, offs_bn], c)
+        c = accumulator.to(dtype)
-            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        c_desc.store([offs_am, offs_bn], c)
 
 
 def matmul_descriptor_persistent(a, b):
@@ -976,16 +492,16 @@ def bench(K, dtype, reps=1000, warmup_reps=10000):
         b = b.T.contiguous()
 
-    #if cublas is not None:
-    #    bench_fn(reps, warmup_reps, cublas_matmul, a, b)
-    #if dtype == torch.float16:
-    #    bench_fn(reps, warmup_reps, torch_matmul, a, b)
+    if cublas is not None:
+        bench_fn(reps, warmup_reps, cublas_matmul, a, b)
+    if dtype == torch.float16:
+        bench_fn(reps, warmup_reps, torch_matmul, a, b)
     bench_fn(reps, warmup_reps, matmul, a, b.T)
     bench_fn(reps, warmup_reps, matmul_persistent_fused, a, b.T)
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
-    #if supports_tma():
-    #    bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-    #    bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
+    if supports_tma():
+        bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
+        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
 
 
 def validate(M, N, K, dtype):

From e422788fa758b25469e97c80be4173c8d8110857 Mon Sep 17 00:00:00 2001
From: Mogball
Date: Tue, 28 Jan 2025 01:54:55 -0500
Subject: [PATCH 09/32] remove unused include

---
 python/tutorials/09-persistent-matmul.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 260f3c2e3f65..d080cc795d10 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -26,7 +26,6 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 import triton.profiler as proton
-import pathlib
 from contextlib import contextmanager
 
 from typing import Optional

From daca5569da85943684dc79e3bc8a5fd6d84da7ef Mon Sep 17 00:00:00 2001
From: Mogball
Date: Tue, 28 Jan 2025 01:57:34 -0500
Subject: [PATCH 10/32] remove invalid bench

---
 python/tutorials/09-persistent-matmul.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index d080cc795d10..981cf71128e9 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -496,7 +496,6 @@ def bench(K, dtype, reps=1000, warmup_reps=10000):
     if dtype == torch.float16:
         bench_fn(reps, warmup_reps, torch_matmul, a, b)
     bench_fn(reps, warmup_reps, matmul, a, b.T)
-    bench_fn(reps, warmup_reps, matmul_persistent_fused, a, b.T)
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)

From 262383d22d8a51f52cf9c0af2e62e723baa13828 Mon Sep 17 00:00:00 2001
From: Mogball
Date: Tue, 28 Jan 2025 12:00:28 -0500
Subject: [PATCH 11/32] fix conflict

---
 lib/Analysis/AxisInfo.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 01cc42b02ca9..d51f28273d77 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1035,13 +1035,8 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
                   CastOpAxisInfoVisitor,
                   CastOpAxisInfoVisitor>();
   visitors.append();
-<<<<<<< HEAD
-  visitors.append,
-                  ConstantOpAxisInfoVisitor>();
   visitors.append();
-=======
   visitors.append();
->>>>>>> origin/main
   visitors.append,
                   AddSubOpAxisInfoVisitor,
                   AddSubOpAxisInfoVisitor>();

From b423b45601fbbd958dacab73c35e0aa77744b143 Mon Sep 17 00:00:00 2001
From: Mogball
Date: Tue, 28 Jan 2025 12:54:10 -0500
Subject: [PATCH 12/32] fix lit tests

---
 .../TritonGPU/Transforms/FuseNestedLoops.cpp |  11 +-
 test/TritonGPU/fuse-nested-loops.mlir        | 174 +++++++++---------
 third_party/nvidia/backend/compiler.py       |   3 +-
 3 files changed, 97 insertions(+), 91 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
index 3f5d8a9461a1..5d2ad8eaa4e6 100644
--- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp
@@ -20,6 +20,7 @@ namespace gpu {
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
 static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute";
+static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse";
 
 namespace {
 struct FuseNestedLoopsPass
@@ -234,7 +235,7 @@ static Logue createLogueFrom(llvm::iterator_range ops,
 // recursively.
 static bool canHoistLoopBoundComputation(Operation *op) {
   auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); };
-  return isPure(op) && op->hasTrait() &&
+  return isMemoryEffectFree(op) &&
          llvm::all_of(op->getOperandTypes(), isScalar) &&
          llvm::all_of(op->getResultTypes(), isScalar);
 }
@@ -794,7 +795,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
     llvm::append_range(outerOuts,
                        logueIf.getResults().slice(1, logue.getNumOutputs()));
   }
-  llvm::append_range(outerOuts, epilogue.getOutputs());
+  llvm::append_range(outerOuts, epilogueIf.getResults());
   b.setInsertionPointToEnd(fused.getBody());
   b.create(outerOuts);
 
@@ -825,6 +826,9 @@ static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) {
 
 // Fuse simple loop nests with a single outer and inner loop, and where the
 // inner loop has a `tt.dot` operation.
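(For reference, a minimal Triton-level sketch of the loop-nest shape this heuristic targets: a persistent outer loop over output tiles whose single inner loop walks K and carries the dot. The kernel name, the flat tile-to-pid mapping, the pointer arithmetic, and the assumption of row-major fp32 inputs with sizes divisible by the block sizes are illustrative only and are not part of the patch.)

    import triton
    import triton.language as tl

    @triton.jit
    def tile_loop_with_inner_dot(a_ptr, b_ptr, c_ptr, M, N, K,
                                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                                 BLOCK_K: tl.constexpr, NUM_SMS: tl.constexpr):
        start = tl.program_id(0)
        num_pid_n = tl.cdiv(N, BLOCK_N)
        num_tiles = tl.cdiv(M, BLOCK_M) * num_pid_n
        # Outer loop: each program walks several output tiles (persistent style).
        for tile_id in range(start, num_tiles, NUM_SMS):
            pid_m = tile_id // num_pid_n
            pid_n = tile_id % num_pid_n
            offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
            # Inner loop: the K loop that carries the dot the pass looks for.
            for ki in range(0, tl.cdiv(K, BLOCK_K)):
                offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
                a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
                b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
                acc = tl.dot(a, b, acc)
            tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)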
static bool shouldFuse(const LoopNest &nest) { + if (nest.root->loop->hasAttr(kAlwaysFuseAttrName)) + return true; + if (nest.nodes.size() != 2 || nest.root->children.size() != 1) return false; @@ -1024,7 +1028,8 @@ void FuseNestedLoopsPass::runOnOperation() { for (LoopNest &nest : nests) { if (!shouldFuse(nest)) continue; - if (failed(speculateInnerLoopLength(nest, domInfo))) + if (!nest.root->loop->hasAttr(kAlwaysFuseAttrName) && + failed(speculateInnerLoopLength(nest, domInfo))) continue; flattenLoopNest(nest.root, domInfo); } diff --git a/test/TritonGPU/fuse-nested-loops.mlir b/test/TritonGPU/fuse-nested-loops.mlir index b4a0a5bd8942..59eecf055073 100644 --- a/test/TritonGPU/fuse-nested-loops.mlir +++ b/test/TritonGPU/fuse-nested-loops.mlir @@ -17,7 +17,7 @@ tt.func @no_fusion(%lb: index, %ub: index, %step: index) -> index { // CHECK-NEXT: yield scf.yield %1 : index // CHECK-NEXT: } - } + } {"ttg.always-fuse"} // CHECK-NEXT: after.loop "after.loop"() : () -> () tt.return %0 : index @@ -48,30 +48,34 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]] // T = -1 - // i = lbi + // i = lbi - stepi // j = None // for _ in range(total_iters): // - // CHECK: [[UNDEF_I64:%.*]] = ub.poison : i64 + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] // CHECK: scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( - // CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I:%.*]] = [[LBI]], [[J_ARG:%.*]] = [[UNDEF_I64]]) -> (i64, i64, i64) : i64 { + // CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I_ARG:%.*]] = [[I_INIT]], [[J_ARG:%.*]] = %c0_i64) -> (i64, i64, i64) : i64 { scf.for %i = %lbi to %ubi step %stepi : i64 { - // T = (T + 1) % inner_len + // T = 0 if T == (inner_len - 1) else T + 1 // // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 - // CHECK-NEXT: [[T:%.*]] = arith.remsi [[T_PLUS_1]], [[INNER_LEN]] + // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 + // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]] + // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]] // if T == 0: + // i += stepi // prologue(i) // j = lbj // // CHECK: [[START:%.*]] = arith.subi %c0_i64, %c0_i64 : i64 // CHECK-NEXT: [[PROLOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[START]] - // CHECK-NEXT: [[J:%.*]] = scf.if [[PROLOGUE_COND]] -> (i64) { + // CHECK-NEXT: [[JI:%.*]]:2 = scf.if [[PROLOGUE_COND]] -> (i64, i64) { + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] // CHECK-NEXT: "prologue"([[I]]) : (i64) -> () - // CHECK-NEXT: yield [[LBJ]] + // CHECK-NEXT: yield [[LBJ]], [[I]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[J_ARG]] + // CHECK-NEXT: yield [[J_ARG]], [[I_ARG]] // CHECK-NEXT: } "prologue"(%i) : (i64) -> () @@ -84,11 +88,11 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-NEXT: [[LT:%.*]] = arith.cmpi slt, [[T]], [[END]] // CHECK-NEXT: [[COND:%.*]] = arith.andi [[GE]], [[LT]] // CHECK-NEXT: [[J_NEXT:%.*]] = scf.if [[COND]] -> (i64) { - // CHECK-NEXT: "body"([[I]], [[J]]) : (i64, i64) -> () - // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[J]], [[STEPJ]] + // CHECK-NEXT: "body"([[JI]]#1, [[JI]]#0) : (i64, i64) -> () + // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[JI]]#0, [[STEPJ]] // CHECK-NEXT: yield [[J_INCR]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[J]] + // CHECK-NEXT: yield [[JI]]#0 // CHECK-NEXT: } scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () @@ 
-98,19 +102,15 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // epilogue(i) // i += stepi // - // CHECK: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 // CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]] - // CHECK-NEXT: [[I_NEXT:%.*]] = scf.if [[EPILOGUE_COND]] -> (i64) { - // CHECK-NEXT: "epilogue"([[I]]) : (i64) -> () - // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]] - // CHECK-NEXT: yield [[I_INCR]] + // CHECK-NEXT: scf.if [[EPILOGUE_COND]] { + // CHECK-NEXT: "epilogue"([[JI]]#1) : (i64) -> () // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[I]] // CHECK-NEXT: } "epilogue"(%i) : (i64) -> () - // CHECK-NEXT: yield [[T]], [[I_NEXT]], [[J_NEXT]] : i64, i64, i64 - } + // CHECK-NEXT: yield [[T]], [[JI]]#1, [[J_NEXT]] : i64, i64, i64 + } {"ttg.always-fuse"} tt.return } @@ -118,32 +118,34 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64 // CHECK-SAME: [[INOUT:%.*]]: index tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64, %inout: index) -> index { - // CHECK-DAG: [[UNDEF_I64:%.*]] = ub.poison : i64 - // CHECK-DAG: [[UNDEF_INDEX:%.*]] = ub.poison : index + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] // CHECK: [[OUTER_OUTS:%.*]]:7 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args( // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, - // CHECK-SAME: [[I:%arg[0-9]+]] = [[LBI]] + // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]] // CHECK-SAME: [[M:%arg[0-9]+]] = [[INOUT]] - // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = [[UNDEF_I64]] - // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]] - // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]] - // CHECK-SAME: [[EPILOGUE_OUT_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]] + // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = %c0_i64 + // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = %c0 + // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = %c0 + // CHECK-SAME: [[EPILOGUE_OUT_ARG:%arg[0-9]+]] = %c0 // CHECK-SAME: ) -> (i64, i64, index, i64, index, index, index) : i64 { %outer_out = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %inout) -> index : i64 { // if T == 0: + // i += stepi // prologue(i) // j = lbj // - // CHECK: [[PROLOGUE_OUTS:%.*]]:3 = scf.if %{{[0-9]+}} -> (i64, index, index) { + // CHECK: [[PROLOGUE_OUTS:%.*]]:4 = scf.if %{{[0-9]+}} -> (i64, index, index, i64) { + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] // CHECK-NEXT: [[PROLOGUE_RES:%.*]] = "prologue"([[I]], [[INOUT]], [[M]]) : (i64, index, index) -> index - // CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]] + // CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]], [[I]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]] + // CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]], [[I_ARG]] // CHECK-NEXT: } // // J := [[PROLOGUE_OUTS]]#0 // PROLOGUE_OUT := [[PROLOGUE_OUTS]]#1 // K := [[PROLOGUE_OUTS]]#2 + // I := [[PROLOGUE_OUTS]]#3 %prologue_out = "prologue"(%i, %inout, %m) : (i64, index, index) -> index // if T >= 0 and T < len_j: @@ -151,7 +153,7 @@ tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // j += stepj // // CHECK: [[BODY_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) { - // CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[I]], [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, 
index, index, index) -> index + // CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, index, index, index) -> index // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[PROLOGUE_OUTS]]#0, [[STEPJ]] // CHECK-NEXT: yield [[J_INCR]], [[BODY_OUT]] // CHECK-NEXT: } else { @@ -166,18 +168,17 @@ tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // epilogue(i) // i += stepi // - // CHECK: [[EPILOGUE_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) { - // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[I]], [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index - // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]] - // CHECK-NEXT: yield [[I_INCR]], [[EPILOGUE_OUT]] + // CHECK: [[EPILOGUE_OUTS:%.*]] = scf.if {{.*}} -> (index) { + // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index + // CHECK-NEXT: yield [[EPILOGUE_OUT]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[I]], [[EPILOGUE_OUT_ARG]] + // CHECK-NEXT: yield [[EPILOGUE_OUT_ARG]] // CHECK-NEXT: } %epilogue_out = "epilogue"(%i, %prologue_out, %inner_out, %m) : (i64, index, index, index) -> index - // CHECK-NEXT: yield %{{.*}}, [[EPILOGUE_OUTS]]#0, [[EPILOGUE_OUTS]]#1, [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1, [[EPILOGUE_OUTS]]#1 : i64, i64, index, i64, index, index, index + // CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1, [[EPILOGUE_OUTS]] : i64, i64, index, i64, index, index, index scf.yield %epilogue_out : index - } + } {"ttg.always-fuse"} // CHECK: return [[OUTER_OUTS]]#2 tt.return %outer_out : index } @@ -213,34 +214,36 @@ tt.func @multiple_loops( // CHECK: [[INNER_LEN:%.*]] = arith.subi [[PLEN3]], %c2_i64 // CHECK-NEXT: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]] - // CHECK: [[UNDEF_I64:%.*]] = ub.poison : i64 - // CHECK: [[UNDEF_F32:%.*]] = ub.poison : f32 + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] // CHECK: [[OUTS:%.*]]:13 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, - // CHECK-SAME: [[I:%arg[0-9]+]] = [[LBI]], + // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]], // CHECK-SAME: [[M:%arg[0-9]+]] = [[M0]], - // CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = [[UNDEF_I64]], - // CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = [[UNDEF_I64]], - // CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = [[UNDEF_I64]], - // CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = [[UNDEF_F32]], - // CHECK-SAME: [[EPILOGUE_ARG:%arg[0-9]+]] = [[UNDEF_F32]]) + // CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[EPILOGUE_ARG:%arg[0-9]+]] = %cst) %mN = 
scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %m0) -> f32 : i64 { // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 - // CHECK-NEXT: [[T:%.*]] = arith.remsi [[T_PLUS_1]], [[INNER_LEN]] + // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 + // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]] + // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]] // CHECK: [[START0:%.*]] = arith.subi [[PLEN0]], %c0_i64 // CHECK-NEXT: [[PROLOGUE_COND0:%.*]] = arith.cmpi eq, [[T]], [[START0]] - // CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND0]] + // CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:4 = scf.if [[PROLOGUE_COND0]] + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] // CHECK-NEXT: [[RES:%.*]] = "prologue0"([[I]], [[M]]) - // CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]] + // CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]], [[I]] // CHECK-NEXT: else - // CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]] + // CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]], [[I_ARG]] %k00 = "prologue0"(%i, %m) : (i64, f32) -> f32 // CHECK: [[END0:%.*]] = arith.addi [[START0]], [[LEN_J0]] @@ -248,7 +251,7 @@ tt.func @multiple_loops( // CHECK-NEXT: [[LT0:%.*]] = arith.cmpi slt, [[T]], [[END0]] // CHECK-NEXT: [[BODY_COND0:%.*]] = arith.andi [[GE0]], [[LT0]] // CHECK-NEXT: [[BODY0_OUTS:%.*]]:2 = scf.if [[BODY_COND0]] - // CHECK-NEXT: [[RES:%.*]] = "body0"([[I]], [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2) + // CHECK-NEXT: [[RES:%.*]] = "body0"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2) // CHECK-NEXT: [[NEXT_J0:%.*]] = arith.addi [[PROLOGUE0_OUTS]]#0, [[STEPJ0]] // CHECK-NEXT: yield [[NEXT_J0]], [[RES]] // CHECK-NEXT: else @@ -261,7 +264,7 @@ tt.func @multiple_loops( // CHECK: [[START1:%.*]] = arith.subi [[PLEN1]], %c1_i64 // CHECK-NEXT: [[PROLOGUE_COND1:%.*]] = arith.cmpi eq, [[T]], [[START1]] // CHECK-NEXT: [[PROLOGUE1_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND1]] - // CHECK-NEXT: [[RES:%.*]] = "prologue1"([[I]], [[BODY0_OUTS]]#1) + // CHECK-NEXT: [[RES:%.*]] = "prologue1"([[PROLOGUE0_OUTS]]#3, [[BODY0_OUTS]]#1) // CHECK-NEXT: yield [[LBJ1]], [[RES]], [[RES]] // CHECK-NEXT: else // CHECK-NEXT: yield [[J1_ARG]], [[PROLOGUE1_ARG]], [[BODY1_ARG]] @@ -272,7 +275,7 @@ tt.func @multiple_loops( // CHECK-NEXT: [[LT1:%.*]] = arith.cmpi slt, [[T]], [[END1]] // CHECK-NEXT: [[BODY_COND1:%.*]] = arith.andi [[GE1]], [[LT1]] // CHECK-NEXT: [[BODY1_OUTS:%.*]]:2 = scf.if [[BODY_COND1]] - // CHECK-NEXT: [[RES:%.*]] = "body1"([[I]], [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2) + // CHECK-NEXT: [[RES:%.*]] = "body1"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2) // CHECK-NEXT: [[NEXT_J1:%.*]] = arith.addi [[PROLOGUE1_OUTS]]#0, [[STEPJ1]] // CHECK-NEXT: yield [[NEXT_J1]], [[RES]] // CHECK-NEXT: else @@ -285,7 +288,7 @@ tt.func @multiple_loops( // CHECK: [[START2:%.*]] = arith.subi [[PLEN2]], %c2_i64 // CHECK-NEXT: [[PROLOGUE_COND2:%.*]] = arith.cmpi eq, [[T]], [[START2]] // CHECK-NEXT: [[PROLOGUE2_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND2]] - // CHECK-NEXT: [[RES:%.*]] = "prologue2"([[I]], [[BODY1_OUTS]]#1) + // CHECK-NEXT: [[RES:%.*]] = "prologue2"([[PROLOGUE0_OUTS]]#3, [[BODY1_OUTS]]#1) // CHECK-NEXT: yield [[LBJ2]], [[RES]], [[RES]] // CHECK-NEXT: else // CHECK-NEXT: yield [[J2_ARG]], [[PROLOGUE2_ARG]], [[BODY2_ARG]] @@ -296,7 +299,7 @@ tt.func @multiple_loops( // CHECK-NEXT: [[LT2:%.*]] = arith.cmpi slt, [[T]], [[END2]] // CHECK-NEXT: [[BODY_COND2:%.*]] = arith.andi [[GE2]], [[LT2]] // CHECK-NEXT: 
[[BODY2_OUTS:%.*]]:2 = scf.if [[BODY_COND2]] - // CHECK-NEXT: [[RES:%.*]] = "body2"([[I]], [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2) + // CHECK-NEXT: [[RES:%.*]] = "body2"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2) // CHECK-NEXT: [[NEXT_J2:%.*]] = arith.addi [[PROLOGUE2_OUTS]]#0, [[STEPJ2]] // CHECK-NEXT: yield [[NEXT_J2]], [[RES]] // CHECK-NEXT: else @@ -306,21 +309,19 @@ tt.func @multiple_loops( scf.yield %res : f32 } - // CHECK: [[END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 - // CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[END]] - // CHECK-NEXT: [[EPILOGUE_OUTS:%.*]]:2 = scf.if [[EPILOGUE_COND]] - // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[I]], [[BODY2_OUTS]]#1) - // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]] - // CHECK-NEXT: yield [[I_INCR]], [[RES]] + // CHECK: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]] + // CHECK-NEXT: [[EPILOGUE_OUTS:%.*]] = scf.if [[EPILOGUE_COND]] + // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[PROLOGUE0_OUTS]]#3, [[BODY2_OUTS]]#1) + // CHECK-NEXT: yield [[RES]] // CHECK-NEXT: else - // CHECK-NEXT: yield [[I]], [[EPILOGUE_ARG]] + // CHECK-NEXT: yield [[EPILOGUE_ARG]] %out = "epilogue"(%i, %k2N) : (i64, f32) -> f32 - // CHECK: scf.yield [[T]], [[EPILOGUE_OUTS]]#0, [[EPILOGUE_OUTS]]#1, + // CHECK: scf.yield [[T]], [[PROLOGUE0_OUTS]]#3, [[EPILOGUE_OUTS]], // CHECK-SAME: [[BODY0_OUTS]]#0, [[BODY1_OUTS]]#0, [[BODY2_OUTS]]#0, - // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1, [[EPILOGUE_OUTS]]#1 : + // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1, [[EPILOGUE_OUTS]] : scf.yield %out : f32 - } + } {"ttg.always-fuse"} // CHECK: return [[OUTS]]#2 tt.return %mN : f32 } @@ -332,12 +333,12 @@ tt.func @two_loop_nests(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () } - } + } {"ttg.always-fuse"} scf.for %i = %lbi to %ubi step %stepi : i64 { scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () } - } + } {"ttg.always-fuse"} // CHECK-NOT: scf.for // CHECK: tt.return tt.return @@ -360,16 +361,16 @@ tt.func @hoist_loop_bound_computations(%lbi: i64, %ubi: i64, %stepi: i64) { %lbj = arith.addi %lbi, %stepi : i64 %ubj = arith.addi %ubi, %stepi : i64 %stepj = arith.addi %stepi, %stepi : i64 - // CHECK: [[J:%.*]] = scf.if - // CHECK-NEXT: yield [[LBJ]] + // CHECK: [[J:%.*]]:2 = scf.if + // CHECK: yield [[LBJ]] // CHECK: scf.if // CHECK-NEXT: "body" - // CHECK-NEXT: arith.addi [[J]], [[STEPJ]] + // CHECK-NEXT: arith.addi [[J]]#0, [[STEPJ]] scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () } - } + } {"ttg.always-fuse"} tt.return } @@ -383,25 +384,24 @@ tt.func @cannot_fuse(%lbi: i64, %ubi: i64, %stepi: i64) { scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () } - } + } {"ttg.always-fuse"} tt.return } // CHECK-LABEL: @upcast_i16_to_i32 -// CHECK-SAME: [[LBI:%.*]]: i16, [[UBI:%.*]]: i16, [[STEPI:%.*]]: i16, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16 -tt.func @upcast_i16_to_i32(%lbi: i16, %ubi: i16, %stepi: i16, %lbj: i16, %ubj: i16, %stepj: i16) { - // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i16 - // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i16 +// CHECK-SAME: [[LBI:%.*]]: i32, [[UBI:%.*]]: i32, [[STEPI:%.*]]: i32, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16 +tt.func @upcast_i16_to_i32(%lbi: i32, %ubi: i32, %stepi: i32, %lbj: 
i16, %ubj: i16, %stepj: i16) { + // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i32 + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i32 // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : i16 // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : i16 // CHECK: arith.extsi [[LEN_J]] : i16 to i32 - // CHECK: arith.extsi [[LEN_I]] : i16 to i32 - scf.for %i = %lbi to %ubi step %stepi : i16 { + scf.for %i = %lbi to %ubi step %stepi : i32 { scf.for %j = %lbj to %ubj step %stepj : i16 { - "body"(%i, %j) : (i16, i16) -> () + "body"(%i, %j) : (i32, i16) -> () } - } + } {"ttg.always-fuse"} tt.return } @@ -419,7 +419,7 @@ tt.func @upcast_index_to_i64(%lbi: index, %ubi: index, %stepi: index, %lbj: inde scf.for %j = %lbj to %ubj step %stepj { "body"(%i, %j) : (index, index) -> () } - } + } {"ttg.always-fuse"} tt.return } @@ -435,7 +435,7 @@ tt.func @triple_loop_nest( "body"(%i, %j, %k) : (i64, i64, i64) -> () } } - } + } {"ttg.always-fuse"} // CHECK-NOT: scf.for // CHECK: tt.return tt.return diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 4dfeab265a05..245d97c21069 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -252,8 +252,9 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_cse(pm) if capability // 10 >= 8: passes.ttgpuir.add_fuse_nested_loops(pm) - passes.common.add_licm(pm) passes.common.add_canonicalizer(pm) + passes.common.add_licm(pm) # run LICM after loop nest fusion + if capability // 10 >= 8: passes.ttgpuir.add_optimize_accumulator_init(pm) passes.common.add_canonicalizer(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) From 89a1ab97fe76a647e2b84cb9a5da230a8aa79226 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 28 Jan 2025 16:52:20 -0500 Subject: [PATCH 13/32] fix pipeline --- third_party/nvidia/backend/compiler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 16703fbeac1c..1dbbdff6c2c7 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -269,7 +269,12 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) if capability // 10 >= 10: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_licm(pm) passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) passes.ttgpuir.add_combine_tensor_select_and_if(pm) From 94cef6a85c41684202962825a8a1af326b2090c1 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 28 Jan 2025 17:06:54 -0500 Subject: [PATCH 14/32] xd --- python/tutorials/09-persistent-matmul.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 517ffab0c7ef..b916944175df 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -242,9 +242,9 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - num_pid_in_group = GROUP_SIZE_M * num_pid_n offs_k_for_mask = tl.arange(0, 
BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n for tile_id in range(start_pid, num_tiles, NUM_SMS): group_id = tile_id // num_pid_in_group @@ -295,7 +295,6 @@ def matmul_persistent(a, b): c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) - matmul_kernel_persistent[grid]( a, b, c, # M, N, K, # @@ -339,11 +338,11 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - num_pid_in_group = GROUP_SIZE_M * num_pid_n # tile_id_c is used in the epilogue to break the dependency between # the prologue and the epilogue tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n for tile_id in range(start_pid, num_tiles, NUM_SMS): group_id = tile_id // num_pid_in_group @@ -483,7 +482,6 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - num_pid_in_group = GROUP_SIZE_M * num_pid_n a_desc = tl._experimental_make_tensor_descriptor( a_ptr, @@ -504,6 +502,8 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2], ) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + for tile_id in range(start_pid, num_tiles, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M @@ -522,8 +522,6 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # b = b_desc.load([offs_bn, offs_k]) accumulator = tl.dot(a, b.T, accumulator) - c = accumulator.to(dtype) - if EPILOGUE_SUBTILE: acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) acc = tl.permute(acc, (0, 2, 1)) From 62eaecb7742504a99aeba4d35d4bac0e1bc63ad1 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 29 Jan 2025 22:22:33 -0500 Subject: [PATCH 15/32] it actually works --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 213 ++++++++++-------- python/src/ir.cc | 7 + third_party/nvidia/backend/compiler.py | 2 - 3 files changed, 123 insertions(+), 99 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index 5d2ad8eaa4e6..fcbeba68fab8 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -335,6 +335,37 @@ static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { return b.create(attr); } +static scf::YieldOp getYield(Region &body) { + return cast(body.front().back()); +} + +static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp, + llvm::BitVector indices, + SmallVector replaceWith) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(ifOp); + while (indices.size() < ifOp.getNumResults()) + indices.push_back(false); + + getYield(ifOp.getThenRegion())->eraseOperands(indices); + getYield(ifOp.getElseRegion())->eraseOperands(indices); + + TypeRange newTypes = getYield(ifOp.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, ifOp.getCondition()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + SmallVector replacements; + auto replIt = replaceWith.begin(); + auto resIt = newIf->result_begin(); + for (unsigned i : 
llvm::seq(ifOp.getNumResults())) + replacements.push_back(indices[i] ? *replIt++ : *resIt++); + assert(ValueRange(replacements).getTypes() == ifOp.getResultTypes()); + ifOp.replaceAllUsesWith(replacements); + ifOp.erase(); + return newIf; +} + // Given a one level loop nest in the form // // for i in range(lbi, ubi, stepi): @@ -597,7 +628,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { fusedInits.push_back(createPoisonOrZero(b, resultType)); } unsigned logueOutsStartIdx = fusedInits.size(); - for (Logue &logue : logues) { + for (Logue &logue : llvm::drop_end(logues)) { for (Type outputType : logue.getOutputTypes()) fusedInits.push_back(createPoisonOrZero(b, outputType)); } @@ -729,8 +760,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { // Splice bodyk into the `then` region. inner.getBody()->eraseArguments([](Value arg) { return true; }); bodyIf.getThenRegion().takeBody(inner.getBodyRegion()); - auto yield = - cast(bodyIf.getThenRegion().front().getTerminator()); + auto yield = getYield(bodyIf.getThenRegion()); b.setInsertionPoint(yield); Value nextJk = b.create(jk, inner.getStep()); yield->insertOperands(0, nextJk); @@ -766,44 +796,92 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { b.create(arith::CmpIPredicate::eq, T, b.create(innerLen, intTyCst(1))); auto epilogueIf = - b.create(epilogue.getOutputTypes(), epilogueCond); + b.create(outer.getYieldedValues().getTypes(), epilogueCond); Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion()); epilogue.moveBefore(thenBlock, thenBlock->end()); b.setInsertionPointToEnd(thenBlock); - b.create(epilogue.getOutputs()); + b.create(outer.getYieldedValues()); b.createBlock(&epilogueIf.getElseRegion()); - SmallVector elseOuts(logueOutsIt, - logueOutsIt + epilogue.getNumOutputs()); - b.create(elseOuts); - epilogue.replaceAllUsesWith(epilogueIf.getResults(), - epilogueIf.getThenRegion()); + b.create(fused.getRegionIterArgs().slice( + outerArgsStartIdx, outer.getNumRegionIterArgs())); // Finally, create the yield of the fused loop. SmallVector outerOuts{T, i}; - llvm::append_range(outerOuts, outer.getYieldedValues()); + llvm::append_range(outerOuts, epilogueIf.getResults()); for (scf::IfOp bodyIf : bodyIfs) outerOuts.push_back(/*jk=*/bodyIf.getResult(0)); for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) { llvm::append_range(outerOuts, bodyIf.getResults().slice(1, loop.getNumResults())); - loop.erase(); } for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) { llvm::append_range(outerOuts, logueIf.getResults().slice(1, logue.getNumOutputs())); } - llvm::append_range(outerOuts, epilogueIf.getResults()); b.setInsertionPointToEnd(fused.getBody()); - b.create(outerOuts); + auto outerYield = b.create(outerOuts); outer.replaceAllUsesWith( fused.getResults().slice(outerArgsStartIdx, outer.getNumResults())); - outer.erase(); + + // Reduce dependencies across inner loops by hoisting the initialization of + // inner loop iter args to the outer loop when possible, and then placing the + // reset of these values in the epilogue. 
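(For intuition, a small Python model of the fused schedule this code emits, with the inner-loop iter arg initialized once before the fused loop and reset in the epilogue as the comment above describes. Scalar values stand in for the MLIR SSA values; the control flow follows the pseudocode comments in the lit test, which is a sketch rather than the exact IR the pass builds.)

    def fused_loop_schedule(num_outer, inner_len, body, store):
        # T tracks the position within the (virtual) inner loop; the outer
        # induction variable only advances in the prologue (T == 0).
        T = -1
        i = -1            # models "i = lbi - stepi" with lbi = 0, stepi = 1
        acc = 0.0         # hoisted initialization of the inner-loop iter arg
        for _ in range(num_outer * inner_len):
            T = 0 if T == inner_len - 1 else T + 1
            if T == 0:                # prologue
                i += 1
            acc = body(i, T, acc)     # inner-loop body runs every fused iteration
            if T == inner_len - 1:    # epilogue
                store(i, acc)
                acc = 0.0             # reset moved into the epilogue
        return i

With body accumulating and store appending to a list, this reproduces num_outer independent inner reductions, which is exactly the behavior the hoisting and epilogue reset must preserve.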
+ auto fusedInitsIt = fused.getInitsMutable().begin() + innerOutsStartIdx; + auto fusedArgsIt = fused.getRegionIterArgs().begin() + innerOutsStartIdx; + auto fusedYieldIt = getYield(fused.getBodyRegion())->getOpOperands().begin() + + innerOutsStartIdx; + SmallVector yieldsToUpdate; + SmallVector reset, forwarded; + for (auto [loop, ifOp, bodyIf, prologue] : + llvm::zip(innerLoops, prologueIfs, bodyIfs, logues)) { + unsigned numResults = loop.getNumResults(); + unsigned prologueSkip = 1 + prologue.getNumOutputs(); + + llvm::BitVector removeIndices(prologueSkip + numResults); + SmallVector replaceWith; + for (auto [i, init] : llvm::enumerate(loop.getInits())) { + if (init.getParentRegion() == &loop.getBodyRegion()) + continue; + // Initialize this in the outer loop. + fusedInitsIt[i].assign(init); + replaceWith.push_back(fusedArgsIt[i]); + removeIndices.set(prologueSkip + i); + yieldsToUpdate.push_back(&fusedYieldIt[i]); + forwarded.push_back(bodyIf.getResult(1 + i)); + reset.push_back(init); + } + // Remove the initializers in the corresponding prologue. + eraseIfResults(b, ifOp, removeIndices, replaceWith); + + fusedInitsIt += numResults; + fusedArgsIt += numResults; + fusedYieldIt += numResults; + } + if (!yieldsToUpdate.empty()) { + MutableOperandRange(getYield(epilogueIf.getThenRegion())).append(reset); + MutableOperandRange(getYield(epilogueIf.getElseRegion())).append(forwarded); + b.setInsertionPoint(epilogueIf); + TypeRange newTypes = getYield(epilogueIf.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, epilogueIf.getCondition()); + newIf.getThenRegion().takeBody(epilogueIf.getThenRegion()); + newIf.getElseRegion().takeBody(epilogueIf.getElseRegion()); + epilogueIf.replaceAllUsesWith( + newIf.getResults().take_front(epilogueIf.getNumResults())); + ResultRange newResults = + newIf.getResults().drop_front(epilogueIf.getNumResults()); + for (auto [i, yieldOperand] : llvm::enumerate(yieldsToUpdate)) + yieldOperand->set(newResults[i]); + epilogueIf.erase(); + } // Update the parent's loop to the fused loop. + for (scf::ForOp loop : innerLoops) + loop.erase(); + outer.erase(); parent->loop = fused; } @@ -838,18 +916,12 @@ static bool shouldFuse(const LoopNest &nest) { }); } -// Loop-invariant code motion can increase register pressure in combination with -// loop nest fusion. Values hoisted out of the inner loop and in to the prologue -// that are directly used inside the inner loop will need to be added as iter -// args to the fused loop, substantially increasing their liverange. -// -// This function identifies a subgraph of cheap ops that can be sunk and -// determines if doing so will reduce register pressure. +// This function identifies a subgraph of cheap ops that can be sunk between two +// regions in the loop nest and moves them, reducing their liveranges. static void sinkHeavyOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, llvm::iterator_range prologue, - function_ref inSinkRegion, - function_ref shouldSink) { + function_ref inSinkRegion) { llvm::SetVector sunkOps; auto canBeSunk = [&](Operation &op) -> std::pair { if (!isPure(&op) || op.hasTrait()) @@ -880,94 +952,30 @@ static void sinkHeavyOps(Region &limit, Block *sinkBlock, if (sunkOps.empty()) return; - // Analyze the sinking the whole subgraph at once. Breaking up the subgraph is - // a more complicated analysis. - // - // Compute the total size of the fan-ins and fan-outs as the number of - // registers per thread used by the value. This is a heuristic. 
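(A rough Python model of the size estimate used by the removed heuristic below: a value costs roughly the number of registers each thread holds for it, and the subgraph is only sunk when the fan-in is much smaller than the fan-out. The elems_per_thread constant is a made-up stand-in for querying the tensor's linear layout; only the final "fan-in * 4 <= fan-out" comparison is taken from the removed code.)

    def registers_per_thread(shape, elems_per_thread=8):
        # Stand-in for asking the layout how many registers per thread a value
        # occupies; scalar values count as a single register.
        if not shape:
            return 1
        total = 1
        for dim in shape:
            total *= dim
        return max(total // elems_per_thread, 1)

    def should_sink(fan_in_shapes, fan_out_shapes):
        # Sink the subgraph only when it sharply reduces live register pressure.
        fan_in = sum(registers_per_thread(s) for s in fan_in_shapes)
        fan_out = sum(registers_per_thread(s) for s in fan_out_shapes)
        return fan_in * 4 <= fan_out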
- MLIRContext *ctx = sunkOps.front()->getContext(); - auto kRegister = StringAttr::get(ctx, "register"); - auto getSizeEstimate = [&](Type type) { - auto tensor = dyn_cast(type); - if (!tensor) - return 1; - LinearLayout layout = - toLinearLayout(tensor.getShape(), tensor.getEncoding()); - return layout.getInDimSize(kRegister); - }; - - size_t fanOutSize = 0; - for (Operation *root : roots) { - for (Value result : root->getResults()) { - if (result.use_empty()) - continue; - fanOutSize += getSizeEstimate(result.getType()); - } - } - - size_t fanInSize = 0; - DenseSet checked; - for (Operation *op : sunkOps) { - for (Value operand : op->getOperands()) { - // Count each operand only once. - if (!checked.insert(operand).second) - continue; - if (sunkOps.contains(operand.getDefiningOp())) - continue; - if (operand.getParentRegion()->isProperAncestor(&limit)) - continue; - if (llvm::any_of(operand.getUsers(), inSinkRegion)) - continue; - fanInSize += getSizeEstimate(operand.getType()); - } - } - - // Only sink if this will lead to a large reduction. - if (shouldSink(fanInSize, fanOutSize)) { - sunkOps = topologicalSort(sunkOps); - for (Operation *op : sunkOps) - op->moveBefore(sinkBlock, sinkBefore); - } + sunkOps = topologicalSort(sunkOps); + for (Operation *op : sunkOps) + op->moveBefore(sinkBlock, sinkBefore); } -// Sink ops into the inner loop and from the prologue into the epilogue. +// Sink ops from the prologue into the epilogue when possible. static void sinkHeavyOps(scf::ForOp outerLoop, scf::ForOp innerLoop, mlir::DominanceInfo &domInfo) { - Region &limit = outerLoop.getBodyRegion(); - auto inInnerLoop = [&](Operation *op) { - return innerLoop.getBodyRegion().isAncestor(op->getParentRegion()); - }; - // sinkHeavyOps(limit, innerLoop.getBody(), innerLoop.getBody()->begin(), - // {outerLoop.getBody()->begin(), innerLoop->getIterator()}, - // inInnerLoop, [&](size_t fanInSize, size_t fanOutSize) { - // return fanInSize * 4 <= fanOutSize; - // }); - - // Move computations in the prologue that can be done in the epilogue. This is - // always beneficial. auto inEpilogue = [&](Operation *op) { return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false); }; + Region &limit = outerLoop.getBodyRegion(); sinkHeavyOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), {outerLoop.getBody()->begin(), innerLoop->getIterator()}, - inEpilogue, - [&](size_t fanInSize, size_t fanOutSize) { return true; }); + inEpilogue); } // Speculate the length of the inner loop such that the loop is known to execute // at least once. This way, the inner loop body does not have to be placed // inside a conditional in the fused loop, which interacts better with the // pipeliner. -static LogicalResult speculateInnerLoopLength(const LoopNest &nest, +static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop, + scf::ForOp innerLoop, mlir::DominanceInfo &domInfo) { - assert(nest.nodes.size() == 2 && nest.root->children.size() == 1); - - scf::ForOp outerLoop = nest.root->loop; - scf::ForOp innerLoop = nest.root->children.front()->loop; - - // Sink heavy ops first. 
- sinkHeavyOps(outerLoop, innerLoop, domInfo); - innerLoop->setAttr(kMustExecuteAttrName, UnitAttr::get(outerLoop.getContext())); return success(); @@ -1019,6 +1027,17 @@ static LogicalResult speculateInnerLoopLength(const LoopNest &nest, return success(); } +static LogicalResult preprocessLoopNest(const LoopNest &nest, + mlir::DominanceInfo &domInfo) { + assert(nest.nodes.size() == 2 && nest.root->children.size() == 1); + + scf::ForOp &outerLoop = nest.root->loop; + scf::ForOp &innerLoop = nest.root->children.front()->loop; + + sinkHeavyOps(outerLoop, innerLoop, domInfo); + return speculateInnerLoopLength(outerLoop, innerLoop, domInfo); +} + void FuseNestedLoopsPass::runOnOperation() { auto &domInfo = getAnalysis(); @@ -1029,7 +1048,7 @@ void FuseNestedLoopsPass::runOnOperation() { if (!shouldFuse(nest)) continue; if (!nest.root->loop->hasAttr(kAlwaysFuseAttrName) && - failed(speculateInnerLoopLength(nest, domInfo))) + failed(preprocessLoopNest(nest, domInfo))) continue; flattenLoopNest(nest.root, domInfo); } diff --git a/python/src/ir.cc b/python/src/ir.cc index 53451b706ae1..51be5493bcf6 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1737,6 +1737,13 @@ void init_triton_ir(py::module &&m) { printingFlags); } }) + .def("get_pipeline_str", + [](PassManager &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.printAsTextualPipeline(os); + return str; + }) .def("run", [](PassManager &self, ModuleOp &mod) { // TODO: maybe dump module to file and print error for better // diagnostics diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1dbbdff6c2c7..e9992201c186 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -273,8 +273,6 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_canonicalizer(pm) passes.common.add_licm(pm) passes.ttgpuir.add_optimize_accumulator_init(pm) - passes.common.add_canonicalizer(pm) - passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) passes.ttgpuir.add_combine_tensor_select_and_if(pm) From 00bf302b6eda63ac83b85617a5ac9fe6867126a6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 29 Jan 2025 23:45:15 -0500 Subject: [PATCH 16/32] fix tests --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 2 +- test/TritonGPU/fuse-nested-loops.mlir | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index fcbeba68fab8..43d3ba05116d 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -844,7 +844,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { llvm::BitVector removeIndices(prologueSkip + numResults); SmallVector replaceWith; for (auto [i, init] : llvm::enumerate(loop.getInits())) { - if (init.getParentRegion() == &loop.getBodyRegion()) + if (init.getParentRegion() == &fused.getBodyRegion()) continue; // Initialize this in the outer loop. 
fusedInitsIt[i].assign(init); diff --git a/test/TritonGPU/fuse-nested-loops.mlir b/test/TritonGPU/fuse-nested-loops.mlir index 59eecf055073..fb878dfc862a 100644 --- a/test/TritonGPU/fuse-nested-loops.mlir +++ b/test/TritonGPU/fuse-nested-loops.mlir @@ -119,15 +119,14 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-SAME: [[INOUT:%.*]]: index tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64, %inout: index) -> index { // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] - // CHECK: [[OUTER_OUTS:%.*]]:7 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args( + // CHECK: [[OUTER_OUTS:%.*]]:6 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args( // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]] // CHECK-SAME: [[M:%arg[0-9]+]] = [[INOUT]] // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = %c0_i64 // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = %c0 // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = %c0 - // CHECK-SAME: [[EPILOGUE_OUT_ARG:%arg[0-9]+]] = %c0 - // CHECK-SAME: ) -> (i64, i64, index, i64, index, index, index) : i64 { + // CHECK-SAME: ) -> (i64, i64, index, i64, index, index) : i64 { %outer_out = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %inout) -> index : i64 { // if T == 0: // i += stepi @@ -172,11 +171,11 @@ tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index // CHECK-NEXT: yield [[EPILOGUE_OUT]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[EPILOGUE_OUT_ARG]] + // CHECK-NEXT: yield [[M]] // CHECK-NEXT: } %epilogue_out = "epilogue"(%i, %prologue_out, %inner_out, %m) : (i64, index, index, index) -> index - // CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1, [[EPILOGUE_OUTS]] : i64, i64, index, i64, index, index, index + // CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1 : i64, i64, index, i64, index, index scf.yield %epilogue_out : index } {"ttg.always-fuse"} // CHECK: return [[OUTER_OUTS]]#2 @@ -215,7 +214,7 @@ tt.func @multiple_loops( // CHECK-NEXT: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]] // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] - // CHECK: [[OUTS:%.*]]:13 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( + // CHECK: [[OUTS:%.*]]:12 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]], // CHECK-SAME: [[M:%arg[0-9]+]] = [[M0]], @@ -227,8 +226,7 @@ tt.func @multiple_loops( // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = %cst, // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = %cst, // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = %cst, - // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst, - // CHECK-SAME: [[EPILOGUE_ARG:%arg[0-9]+]] = %cst) + // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst) %mN = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %m0) -> f32 : i64 { // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 @@ -314,12 +312,12 @@ tt.func @multiple_loops( // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[PROLOGUE0_OUTS]]#3, [[BODY2_OUTS]]#1) // CHECK-NEXT: yield [[RES]] // CHECK-NEXT: else - // CHECK-NEXT: yield 
[[EPILOGUE_ARG]] + // CHECK-NEXT: yield [[M]] %out = "epilogue"(%i, %k2N) : (i64, f32) -> f32 // CHECK: scf.yield [[T]], [[PROLOGUE0_OUTS]]#3, [[EPILOGUE_OUTS]], // CHECK-SAME: [[BODY0_OUTS]]#0, [[BODY1_OUTS]]#0, [[BODY2_OUTS]]#0, - // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1, [[EPILOGUE_OUTS]] : + // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1 : scf.yield %out : f32 } {"ttg.always-fuse"} // CHECK: return [[OUTS]]#2 From 49061e0d2c0ea3ba172c1ed41e153f21a8737918 Mon Sep 17 00:00:00 2001 From: Mogball Date: Mon, 3 Feb 2025 16:28:06 -0800 Subject: [PATCH 17/32] fmt --- python/tutorials/09-persistent-matmul.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index d8a89fc4ae70..d4a13861c379 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -496,6 +496,9 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2], ) + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n for tile_id in range(start_pid, num_tiles, NUM_SMS): From 93abd41f11ac4b64a5b15a178a7e29a1b6e70750 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 4 Feb 2025 11:43:19 -0800 Subject: [PATCH 18/32] add unit test --- test/TritonGPU/matmul-loop-pipeline.mlir | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index 1a91ab022a78..55879a739c3a 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -47,3 +47,27 @@ tt.func public @scalar_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +tt.func public @_p_matmul_ogs_NNN_fp16xfp16xfp16_128x256x64x1(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 { + %1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> + %2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr, #blocked> + %3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked> + %4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32 + scf.if %4 { + %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > + } {loop.cluster = 5 : i32, loop.stage = 2 : i32} + } {tt.num_stages = 3 : i32} + tt.return +} + +} From 6f9626f479b2d2f83b368346bc5962b87774ea8b Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 4 Feb 2025 11:45:14 -0800 Subject: [PATCH 19/32] fix crash in pipeliner --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 12 ++++++------ test/TritonGPU/matmul-loop-pipeline.mlir | 10 ++++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 
25984f477843..0e4503c3e8a7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -703,14 +703,14 @@ scf::IfOp replaceIfOpWithNewSignature( // Create a new loop before the existing one, with the extra operands. auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); resultTypes.append(newResultTypes.begin(), newResultTypes.end()); - scf::IfOp newIf = rewriter.create( - ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); + scf::IfOp newIf = rewriter.create(ifOp.getLoc(), resultTypes, + ifOp.getCondition()); newIf->setAttrs(ifOp->getAttrs()); - rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), - newIf.thenBlock()->begin()); - rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), - newIf.elseBlock()->begin()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); + scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); for (auto it : llvm::zip(ifOp.getResults(), newIf.getResults().take_front(ifOp.getNumResults()))) diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index 55879a739c3a..8416bf739c49 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -52,18 +52,24 @@ tt.func public @scalar_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} { -tt.func public @_p_matmul_ogs_NNN_fp16xfp16xfp16_128x256x64x1(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { +// CHECK-LABEL: @make_tensor_desc_epilogue +tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c1_i64 = arith.constant 1 : i64 + // CHECK: scf.for scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 { %1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> %2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr, #blocked> %3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked> %4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32 + // CHECK: scf.if scf.if %4 { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK: tt.experimental_tensormap_create + // CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > } {loop.cluster = 5 : i32, loop.stage = 2 : i32} } {tt.num_stages = 3 : i32} From d352ba61738028d82c2ff95abe6f395495f4ba56 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 4 Feb 2025 14:07:39 -0800 Subject: [PATCH 20/32] fix --- python/tutorials/09-persistent-matmul.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index af13ccf82e8d..c4551dd6565e 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -281,6 +281,7 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # b = 
tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, b, accumulator) + tile_id_c += NUM_SMS pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) From 5eb54818ba8148d4678b22be8ceae9840a1b0797 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 4 Feb 2025 21:20:17 -0800 Subject: [PATCH 21/32] test poison axisinfo --- lib/Analysis/AxisInfo.cpp | 2 +- .../TritonToTritonGPU/TritonGPUConversion.cpp | 22 ++++++----- .../TritonToTritonGPUPass.cpp | 2 + test/Conversion/triton_to_tritongpu.mlir | 9 +++++ test/TritonGPU/coalesce.mlir | 39 +++++++++++++++++++ 5 files changed, 63 insertions(+), 11 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index d51f28273d77..2b544db6b1a8 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -282,7 +282,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { if (auto shape = dyn_cast(op.getType())) { unsigned rank = shape.getRank(); return AxisInfo( - /*contiguity=*/AxisInfo::DimVectorT(rank, 1), + /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2), /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 06e75ee18d59..773c01e4a2a0 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -3,6 +3,7 @@ #include #include +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -97,16 +98,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( addDynamicallyLegalDialect([&](Operation *op) { - bool hasLegalRegions = true; - for (auto ®ion : op->getRegions()) { - hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); - } - if (hasLegalRegions && typeConverter.isLegal(op)) { - return true; - } - return false; - }); + scf::SCFDialect, ub::UBDialect>( + [&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); // We have requirements for the data layouts addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 54917a23705c..4b0e8b111216 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -859,6 +860,7 @@ class ConvertTritonToTritonGPU // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? 
populateSCFPatterns(typeConverter, patterns); populateCFPatterns(typeConverter, patterns); + patterns.insert>(typeConverter, context); auto inti = llvm::APSInt(32, false); auto i32_ty = IntegerType::get(mod->getContext(), 32); diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index bab044f46b5c..62763e4f056c 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -143,3 +143,12 @@ tt.func @scatter4_layout(%arg0: !tt.tensordesc>, %arg1: i32, % tt.experimental_descriptor_scatter %arg0[%cst, %arg1], %1 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xf32> tt.return } + +// ----- + +// CHECK-LABEL: @ub_poison +tt.func @ub_poison() { + // CHECK-NEXT: ub.poison : tensor<128x64xf16, #blocked> + %0 = ub.poison : tensor<128x64xf16> + tt.return +} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 25e136514b01..44eb3d47e293 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -160,3 +160,42 @@ module { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + +// CHECK: [[COALESCED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK: @coalesce_poison +tt.func @coalesce_poison(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128xi32, #blocked1> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<128x1xi32, #blocked2> -> tensor<128x1xi32, #blocked3> + %5 = tt.broadcast %4 {axis = 1 : i32} : tensor<128x1xi32, #blocked3> -> tensor<128x64xi32, #blocked3> + %6 = ttg.convert_layout %5 : tensor<128x64xi32, #blocked3> -> tensor<128x64xi32, #blocked> + %7 = tt.addptr %0, %6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + + %8 = ub.poison : tensor<128x64x!tt.ptr, #blocked> + // CHECK: scf.if + %9 = scf.if %arg2 -> (tensor<128x64x!tt.ptr, #blocked>) { + scf.yield %8 : tensor<128x64x!tt.ptr, #blocked> + } else { + scf.yield %7 : tensor<128x64x!tt.ptr, #blocked> + } + // CHECK: [[PTR:%.*]] = ttg.convert_layout %{{.*}} : tensor<128x64x!tt.ptr, #{{.*}}> -> tensor<128x64x!tt.ptr, [[COALESCED_LAYOUT]]> + // CHECK-NEXT: tt.load [[PTR]] + %10 = tt.load %9 : tensor<128x64x!tt.ptr, #blocked> + tt.return +} + +} From cdeacc040123a7504b50ad190b1d0f753e7caf08 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 5 Feb 2025 10:37:23 -0800 Subject: [PATCH 22/32] add tests for everything, switch to flag --- .../TritonGPU/Transforms/FuseNestedLoops.cpp | 53 +++++---- 
python/src/ir.cc | 4 + python/test/unit/language/test_core.py | 12 +++ python/triton/compiler/code_generator.py | 4 + python/triton/language/core.py | 6 +- python/tutorials/09-persistent-matmul.py | 6 +- test/TritonGPU/fuse-nested-loops.mlir | 102 ++++++++++++++++++ 7 files changed, 160 insertions(+), 27 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index 43d3ba05116d..c15cebbcee3a 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -5,6 +5,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "llvm/Support/Debug.h" #include @@ -19,7 +20,11 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +// This attribute is set by the front-end to control whether fusion is on. +static constexpr llvm::StringLiteral kFuseAttr = "tt.fuse"; +// This attribute indicates the inner loop length has been speculated. static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute"; +// This attribute is just used for testing the pass. static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse"; namespace { @@ -322,6 +327,10 @@ static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, return b.create(type, value); } +// To model an "undef" value, i.e. a value that is known to never be read on +// live code paths, create a zero-valued constant where possible, otherwise use +// a poison value. PTXAS appears to generate better code with zeros compared to +// poison values. static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { Type elTy = getElementTypeOrSelf(type); if (!elTy.isIntOrIndexOrFloat() || @@ -878,11 +887,18 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { epilogueIf.erase(); } - // Update the parent's loop to the fused loop. - for (scf::ForOp loop : innerLoops) + // Update the parent's loop to the fused loop. Set the new stage count to the + // max stage count of the inner loops. + int numStages = 1; + for (scf::ForOp loop : innerLoops) { + if (auto stageAttr = loop->getAttrOfType(kNumStagesAttrName)) + numStages = std::max(numStages, stageAttr.getInt()); loop.erase(); + } outer.erase(); parent->loop = fused; + if (numStages > 1) + fused->setAttr(kNumStagesAttrName, b.getI32IntegerAttr(numStages)); } //===----------------------------------------------------------------------===// @@ -907,21 +923,16 @@ static bool shouldFuse(const LoopNest &nest) { if (nest.root->loop->hasAttr(kAlwaysFuseAttrName)) return true; - if (nest.nodes.size() != 2 || nest.root->children.size() != 1) - return false; - - scf::ForOp innerLoop = nest.root->children.front()->loop; - return llvm::any_of(innerLoop.getOps(), [](Operation &op) { - return op.hasTrait(); - }); + // Only fuse simple loop nests. + return nest.nodes.size() == 2 && nest.root->children.size() == 1 && + nest.root->loop->hasAttr(kFuseAttr); } // This function identifies a subgraph of cheap ops that can be sunk between two // regions in the loop nest and moves them, reducing their liveranges. 
-static void sinkHeavyOps(Region &limit, Block *sinkBlock, - Block::iterator sinkBefore, - llvm::iterator_range prologue, - function_ref inSinkRegion) { +static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, + llvm::iterator_range prologue, + function_ref inSinkRegion) { llvm::SetVector sunkOps; auto canBeSunk = [&](Operation &op) -> std::pair { if (!isPure(&op) || op.hasTrait()) @@ -958,15 +969,15 @@ static void sinkHeavyOps(Region &limit, Block *sinkBlock, } // Sink ops from the prologue into the epilogue when possible. -static void sinkHeavyOps(scf::ForOp outerLoop, scf::ForOp innerLoop, - mlir::DominanceInfo &domInfo) { +static void optimizeEpilogueDependencies(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { auto inEpilogue = [&](Operation *op) { return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false); }; Region &limit = outerLoop.getBodyRegion(); - sinkHeavyOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), - {outerLoop.getBody()->begin(), innerLoop->getIterator()}, - inEpilogue); + sinkOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), + {outerLoop.getBody()->begin(), innerLoop->getIterator()}, inEpilogue); } // Speculate the length of the inner loop such that the loop is known to execute @@ -976,10 +987,6 @@ static void sinkHeavyOps(scf::ForOp outerLoop, scf::ForOp innerLoop, static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop, scf::ForOp innerLoop, mlir::DominanceInfo &domInfo) { - innerLoop->setAttr(kMustExecuteAttrName, - UnitAttr::get(outerLoop.getContext())); - return success(); - // The inner loop bounds must be outer-loop invariant to speculate from // outside the loop nest. Location loc = innerLoop.getLoc(); @@ -1034,7 +1041,7 @@ static LogicalResult preprocessLoopNest(const LoopNest &nest, scf::ForOp &outerLoop = nest.root->loop; scf::ForOp &innerLoop = nest.root->children.front()->loop; - sinkHeavyOps(outerLoop, innerLoop, domInfo); + optimizeEpilogueDependencies(outerLoop, innerLoop, domInfo); return speculateInnerLoopLength(outerLoop, innerLoop, domInfo); } diff --git a/python/src/ir.cc b/python/src/ir.cc index 2fc486b18bbb..14fec22e5889 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -458,6 +458,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "attribute", py::module_local()); py::class_(m, "integer_attr", py::module_local()); py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); // Ops py::class_(m, "OpState", py::module_local()) @@ -750,6 +751,9 @@ void init_triton_ir(py::module &&m) { self.restoreInsertionPoint(pt); }) // Attr + .def( + "get_unit_attr", + [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) .def("get_bool_attr", [](TritonOpBuilder &self, bool value) { return self.getBuilder().getBoolAttr(value); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 93ae3bd35a22..b246f0b94b79 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6575,6 +6575,18 @@ def test_tl_range(device): assert 'cp.async.wait_group 6' in ptx +def test_tl_range_fuse(): + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, fuse=True): + for j in tl.range(0, ub): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1,)) + assert "tt.fuse" in compiled_kernel.asm["ttir"] + assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 + + @triton.jit(noinline=True) def 
maxnreg_noinline1(X): tl.store(X, 0) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ae7a0c92e22b..c919423dbf8d 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -998,6 +998,7 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None + fuse = None if IteratorClass is language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -1008,6 +1009,7 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor + fuse = iterator.fuse elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -1082,6 +1084,8 @@ def visit_For(self, node): for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if fuse: + for_op.set_attr("tt.fuse", self.builder.get_unit_attr()) self.scf_stack.append(node) for_op_body = for_op.get_body(0) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6be9514bd0eb..f7823628907e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2865,9 +2865,12 @@ def kernel(...): :param loop_unroll_factor: Tells the Triton IR level loop unroller how many times to unroll a for loop that this range is used with. Less than 2 for this value implies no unrolling. + :param fuse: automatically fuse the loop nest starting at this loop to + create a single fused loop. The compiler will try to pipeline the fused + loop which can avoid stage stalling. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, fuse=None): if step is None: self.step = constexpr(1) else: @@ -2880,6 +2883,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact self.end = arg2 self.num_stages = num_stages self.loop_unroll_factor = loop_unroll_factor + self.fuse = fuse def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index c4551dd6565e..faf504cf68ee 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -260,7 +260,7 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in range(start_pid, num_tiles, NUM_SMS): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N @@ -353,7 +353,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in range(start_pid, num_tiles, NUM_SMS): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N @@ -505,7 +505,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # tile_id_c = start_pid - NUM_SMS num_pid_in_group = 
GROUP_SIZE_M * num_pid_n - for tile_id in range(start_pid, num_tiles, NUM_SMS): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N diff --git a/test/TritonGPU/fuse-nested-loops.mlir b/test/TritonGPU/fuse-nested-loops.mlir index fb878dfc862a..01937a422cb6 100644 --- a/test/TritonGPU/fuse-nested-loops.mlir +++ b/test/TritonGPU/fuse-nested-loops.mlir @@ -383,6 +383,7 @@ tt.func @cannot_fuse(%lbi: i64, %ubi: i64, %stepi: i64) { "body"(%i, %j) : (i64, i64) -> () } } {"ttg.always-fuse"} + // CHECK-NOT: scf.for tt.return } @@ -438,3 +439,104 @@ tt.func @triple_loop_nest( // CHECK: tt.return tt.return } + +// CHECK-LABEL: @preserve_stage_count +tt.func @preserve_stage_count(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK-COUNT-1: scf.for + scf.for %i = %lb to %ub step %c1_i32 : i32 { + scf.for %j = %lb to %ub step %c1_i32 : i32 { + "body"(%j) : (i32) -> () + scf.yield + } {tt.num_stages = 4 : i32} + scf.for %j = %lb to %ub step %c1_i32 : i32 { + "body"(%j) : (i32) -> () + scf.yield + } {tt.num_stages = 6 : i32} + } {"ttg.always-fuse"} + // CHECK: tt.num_stages = 6 : i32 + // CHECK-NOT: scf.for + tt.return +} + +// CHECK-LABEL: @fuse_attr_speculate +// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32 +tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK: [[DIFF:%.*]] = arith.subi [[UB]], [[LB]] + // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32 + // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32 + + // CHECK: scf.if [[IS_ZERO]] + // CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32 + // CHECK-NEXT: "prologue" + // CHECK-NXET: } + + // CHECK: else + // CHECK-COUNT-1: scf.for + // CHECK-NOT: scf.for + scf.for %i = %lb to %ub step %c1_i32 : i32 { + // CHECK: "prologue" + "prologue"(%i) : (i32) -> () + // CHECK: scf.if %true + scf.for %j = %lb to %ub step %c1_i32 : i32 { + // CHECK-NEXT: "body" + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + } {tt.fuse} + tt.return +} + +// CHECK-LABEL: @speculate_hoist +// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32 +tt.func @speculate_hoist(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK: [[UBJ:%.*]] = arith.addi [[LB]], [[UB]] + // CHECK: [[DIFF:%.*]] = arith.subi [[UBJ]], [[LB]] + // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32 + // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32 + + // CHECK: scf.if [[IS_ZERO]] + scf.for %i = %lb to %ub step %c1_i32 : i32 { + "prologue"(%i) : (i32) -> () + %ubj = arith.addi %lb, %ub : i32 + scf.for %j = %lb to %ubj step %c1_i32 : i32 { + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + } {tt.fuse} + tt.return +} + +// CHECK-LABEL: @sink_prologue_to_epilogue +// CHECK-SAME: [[UB:%.*]]: i32 +tt.func @sink_prologue_to_epilogue(%ub: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: else + // CHECK: scf.for + %0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 { + // CHECK: [[PROLOGUE_OUTS:%.*]]:2 = scf.if + %0 = arith.addi %i, %ub : i32 + // CHECK: scf.if %true + // CHECK-NEXT: "body" + scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 { + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + // CHECK: scf.if + // CHECK-NEXT: [[V0:%.*]] = arith.addi [[PROLOGUE_OUTS]]#1, [[UB]] + // CHECK-NEXT: [[V1:%.*]] = arith.addi [[V0]], [[UB]] + %1 = arith.addi %0, %ub : 
i32 + // CHECK-NEXT: "epilogue"([[V1]]) + "epilogue"(%1) : (i32) -> () + scf.yield %0 : i32 + } {tt.fuse} + + tt.return +} From 55b7088a17cbe7dc30d711950ac867411c0c7c7e Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 5 Feb 2025 10:40:30 -0800 Subject: [PATCH 23/32] add licm for earlier archs --- third_party/nvidia/backend/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index ff4134957df2..6fd169683a42 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -231,7 +231,6 @@ def make_ttir(mod, metadata, opt): passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) - #passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) pm.run(mod) @@ -268,7 +267,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) - if capability // 10 >= 10: + elif capability // 10 >= 10: passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) passes.common.add_licm(pm) @@ -279,6 +278,8 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm) passes.common.add_canonicalizer(pm) + elif: + passes.common.add_licm(pm) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.ttgpuir.add_coalesce_async_copy(pm) From e69df1044b082d3f3bc3a5a0d20fc59b8b361c5e Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 5 Feb 2025 10:42:34 -0800 Subject: [PATCH 24/32] fmt --- python/test/unit/language/test_core.py | 3 ++- third_party/nvidia/backend/compiler.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b246f0b94b79..5cbdf3baaf2d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6576,13 +6576,14 @@ def test_tl_range(device): def test_tl_range_fuse(): + @triton.jit def kernel(ub): for i in tl.range(0, ub, fuse=True): for j in tl.range(0, ub): print("i", i) - compiled_kernel = kernel.warmup(10, grid=(1,)) + compiled_kernel = kernel.warmup(10, grid=(1, )) assert "tt.fuse" in compiled_kernel.asm["ttir"] assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 6fd169683a42..ecfeb3c91607 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -278,7 +278,7 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm) passes.common.add_canonicalizer(pm) - elif: + else: passes.common.add_licm(pm) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) From d63f3dabc8621cb7181821e9897284f1151893ce Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 5 Feb 2025 10:54:29 -0800 Subject: [PATCH 25/32] update --- lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index c15cebbcee3a..cd48bb98a59a 100644 --- 
a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -935,7 +935,7 @@ static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, function_ref inSinkRegion) { llvm::SetVector sunkOps; auto canBeSunk = [&](Operation &op) -> std::pair { - if (!isPure(&op) || op.hasTrait()) + if (!isPure(&op) || isa(op)) return {false, false}; // An op can be sunk if all its users are inside the inner loop or are // marked for sinking. From c86718b87d9f75d22121252e4625c772f2b62693 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 5 Feb 2025 11:50:41 -0800 Subject: [PATCH 26/32] skip test for AMD --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5cbdf3baaf2d..ee563c449f3e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6576,6 +6576,8 @@ def test_tl_range(device): def test_tl_range_fuse(): + if is_hip(): + pytest.skip("loop fusion is not enabled on AMD") @triton.jit def kernel(ub): From f4ee2fbf364601a9bd59d7e68df5a1e359455d6b Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 10:38:07 -0800 Subject: [PATCH 27/32] extract fix --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 12 ++++---- test/TritonGPU/matmul-loop-pipeline.mlir | 30 -------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 0e4503c3e8a7..25984f477843 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -703,14 +703,14 @@ scf::IfOp replaceIfOpWithNewSignature( // Create a new loop before the existing one, with the extra operands. 
auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); resultTypes.append(newResultTypes.begin(), newResultTypes.end()); - scf::IfOp newIf = rewriter.create(ifOp.getLoc(), resultTypes, - ifOp.getCondition()); + scf::IfOp newIf = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); newIf->setAttrs(ifOp->getAttrs()); - newIf.getThenRegion().takeBody(ifOp.getThenRegion()); - newIf.getElseRegion().takeBody(ifOp.getElseRegion()); - scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); - scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), + newIf.thenBlock()->begin()); + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), + newIf.elseBlock()->begin()); for (auto it : llvm::zip(ifOp.getResults(), newIf.getResults().take_front(ifOp.getNumResults()))) diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index 8416bf739c49..1a91ab022a78 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -47,33 +47,3 @@ tt.func public @scalar_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} { - -// CHECK-LABEL: @make_tensor_desc_epilogue -tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c1_i64 = arith.constant 1 : i64 - // CHECK: scf.for - scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 { - %1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> - %2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr, #blocked> - %3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked> - %4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32 - // CHECK: scf.if - scf.if %4 { - // CHECK-NOT: tt.make_tensor_descriptor - // CHECK: tt.experimental_tensormap_create - // CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire - %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > - } {loop.cluster = 5 : i32, loop.stage = 2 : i32} - } {tt.num_stages = 3 : i32} - tt.return -} - -} From f19c78c4cb6dc38b66c98d500e6d7cd5a42d5d95 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 10:38:17 -0800 Subject: [PATCH 28/32] fix crash in tma pipeline --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 12 ++++---- test/TritonGPU/matmul-loop-pipeline.mlir | 30 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 25984f477843..0e4503c3e8a7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -703,14 +703,14 @@ scf::IfOp replaceIfOpWithNewSignature( // Create a new loop before the existing one, with the extra operands. 
auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); resultTypes.append(newResultTypes.begin(), newResultTypes.end()); - scf::IfOp newIf = rewriter.create( - ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); + scf::IfOp newIf = rewriter.create(ifOp.getLoc(), resultTypes, + ifOp.getCondition()); newIf->setAttrs(ifOp->getAttrs()); - rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), - newIf.thenBlock()->begin()); - rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), - newIf.elseBlock()->begin()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); + scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); for (auto it : llvm::zip(ifOp.getResults(), newIf.getResults().take_front(ifOp.getNumResults()))) diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir index 1a91ab022a78..8416bf739c49 100644 --- a/test/TritonGPU/matmul-loop-pipeline.mlir +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -47,3 +47,33 @@ tt.func public @scalar_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} { + +// CHECK-LABEL: @make_tensor_desc_epilogue +tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + // CHECK: scf.for + scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 { + %1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> + %2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr, #blocked> + %3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked> + %4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32 + // CHECK: scf.if + scf.if %4 { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK: tt.experimental_tensormap_create + // CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire + %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > + } {loop.cluster = 5 : i32, loop.stage = 2 : i32} + } {tt.num_stages = 3 : i32} + tt.return +} + +} From 119fc50e369b03f7194b6ee1c48e9745306bbbf3 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 12:03:15 -0800 Subject: [PATCH 29/32] rename fuse to flatten --- lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp | 4 ++-- python/test/unit/language/test_core.py | 4 ++-- python/triton/compiler/code_generator.py | 8 ++++---- python/triton/language/core.py | 10 +++++----- python/tutorials/09-persistent-matmul.py | 6 +++--- test/TritonGPU/fuse-nested-loops.mlir | 6 +++--- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index cd48bb98a59a..411db4af58a5 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -21,7 +21,7 @@ namespace gpu { #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" // This attribute is set by the front-end to control whether fusion is 
on. -static constexpr llvm::StringLiteral kFuseAttr = "tt.fuse"; +static constexpr llvm::StringLiteral kFlattenAttr = "tt.flatten"; // This attribute indicates the inner loop length has been speculated. static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute"; // This attribute is just used for testing the pass. @@ -925,7 +925,7 @@ static bool shouldFuse(const LoopNest &nest) { // Only fuse simple loop nests. return nest.nodes.size() == 2 && nest.root->children.size() == 1 && - nest.root->loop->hasAttr(kFuseAttr); + nest.root->loop->hasAttr(kFlattenAttr); } // This function identifies a subgraph of cheap ops that can be sunk between two diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ee563c449f3e..7242f21705a0 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6581,12 +6581,12 @@ def test_tl_range_fuse(): @triton.jit def kernel(ub): - for i in tl.range(0, ub, fuse=True): + for i in tl.range(0, ub, flatten=True): for j in tl.range(0, ub): print("i", i) compiled_kernel = kernel.warmup(10, grid=(1, )) - assert "tt.fuse" in compiled_kernel.asm["ttir"] + assert "tt.flatten" in compiled_kernel.asm["ttir"] assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index c919423dbf8d..02fcdac45c25 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -998,7 +998,7 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None - fuse = None + flatten = None if IteratorClass is language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -1009,7 +1009,7 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor - fuse = iterator.fuse + flatten = iterator.flatten elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -1084,8 +1084,8 @@ def visit_For(self, node): for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) - if fuse: - for_op.set_attr("tt.fuse", self.builder.get_unit_attr()) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) self.scf_stack.append(node) for_op_body = for_op.get_body(0) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f7823628907e..0bd7566e3944 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2865,12 +2865,12 @@ def kernel(...): :param loop_unroll_factor: Tells the Triton IR level loop unroller how many times to unroll a for loop that this range is used with. Less than 2 for this value implies no unrolling. - :param fuse: automatically fuse the loop nest starting at this loop to - create a single fused loop. The compiler will try to pipeline the fused - loop which can avoid stage stalling. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. 
""" - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, fuse=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, flatten=None): if step is None: self.step = constexpr(1) else: @@ -2883,7 +2883,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact self.end = arg2 self.num_stages = num_stages self.loop_unroll_factor = loop_unroll_factor - self.fuse = fuse + self.flatten = flatten def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index faf504cf68ee..7d528c80420a 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -260,7 +260,7 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N @@ -353,7 +353,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N @@ -505,7 +505,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, fuse=True): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N diff --git a/test/TritonGPU/fuse-nested-loops.mlir b/test/TritonGPU/fuse-nested-loops.mlir index 01937a422cb6..bcaf031e468d 100644 --- a/test/TritonGPU/fuse-nested-loops.mlir +++ b/test/TritonGPU/fuse-nested-loops.mlir @@ -486,7 +486,7 @@ tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) { "body"(%i, %j) : (i32, i32) -> () scf.yield } - } {tt.fuse} + } {tt.flatten} tt.return } @@ -508,7 +508,7 @@ tt.func @speculate_hoist(%lb: i32, %ub: i32) { "body"(%i, %j) : (i32, i32) -> () scf.yield } - } {tt.fuse} + } {tt.flatten} tt.return } @@ -536,7 +536,7 @@ tt.func @sink_prologue_to_epilogue(%ub: i32) { // CHECK-NEXT: "epilogue"([[V1]]) "epilogue"(%1) : (i32) -> () scf.yield %0 : i32 - } {tt.fuse} + } {tt.flatten} tt.return } From 7aa03029493151d6e42995c6dbdd865a9a50cc9a Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 12:04:02 -0800 Subject: [PATCH 30/32] remove unused var --- lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index 411db4af58a5..a84f9ab77be0 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -307,7 +307,6 @@ static unsigned getIntTypeWidth(Type type) { static Value 
computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) { // len(range(lb, ub, step)) = ceildiv(ub - lb, step) // This works even if step is negative. - Location loc = loop.getLoc(); Value diff = b.create(loop.getUpperBound(), loop.getLowerBound()); // Let someone else prove it can be unsigned. From 2a263fd35f3e8aa9a396d4316d694f2fea09a68a Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 12:33:17 -0800 Subject: [PATCH 31/32] add an regression test --- test/TritonGPU/pipeline-loop-nest.mlir | 82 ++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 test/TritonGPU/pipeline-loop-nest.mlir diff --git a/test/TritonGPU/pipeline-loop-nest.mlir b/test/TritonGPU/pipeline-loop-nest.mlir new file mode 100644 index 000000000000..aeccc1e1caa7 --- /dev/null +++ b/test/TritonGPU/pipeline-loop-nest.mlir @@ -0,0 +1,82 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:100},tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,tritongpu-loop-scheduling,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=BLACKWELL +// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:90 },tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,canonicalize,tritongpu-combine-tensor-select-and-if,tritongpu-loop-scheduling,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=HOPPER + +// BLACKWELL-LABEL: @matmul_kernel_tma_persistent +// HOPPER-LABEL: @matmul_kernel_tma_persistent +tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %c63_i32 = arith.constant 63 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c132_i32 = arith.constant 132 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.subi %0, %c132_i32 : i32 + %9 = arith.muli %4, %c8_i32 : i32 + + // BLACKWELL: [[ACC_BUFS:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, + // BLACKWELL: ttg.memdesc_trans + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]] + // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %false + + // BLACKWELL: scf.for + %10 = scf.for %arg6 = %0 to %7 step %c132_i32 iter_args(%arg7 = %8) -> (i32) : i32 { + %11 = arith.divsi %arg6, %9 : i32 + %12 = arith.muli %11, %c8_i32 : i32 + %13 = arith.subi %2, %12 : i32 + %14 = arith.minsi %13, %c8_i32 : i32 + %15 = arith.remsi %arg6, %14 : i32 + %16 = arith.addi %12, %15 : i32 + %17 = arith.remsi %arg6, %9 : i32 + %18 = arith.divsi %17, %14 : i32 + %19 = arith.muli %16, %c128_i32 : i32 + %20 = arith.muli %18, %c128_i32 : i32 + %21 = scf.for %arg8 
= %c0_i32 to %6 step %c1_i32 iter_args(%arg9 = %cst) -> (tensor<128x128xf32>) : i32 { + %35 = arith.muli %arg8, %c64_i32 : i32 + %36 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> + %37 = tt.experimental_descriptor_load %36[%19, %35] : !tt.tensordesc> -> tensor<128x64xf16> + %38 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + %39 = tt.experimental_descriptor_load %38[%20, %35] : !tt.tensordesc> -> tensor<128x64xf16> + // BLACKWELL: ttg.memdesc_trans + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]] + // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %arg + + // HOPPER: [[RESULT:%.*]] = ttng.warp_group_dot {{.*}} isAsync = true + // HOPPER-NEXT: ttng.warp_group_dot_wait [[RESULT]], {{.*}} {pendings = 1 : i32} + %40 = tt.trans %39 {order = array} : tensor<128x64xf16> -> tensor<64x128xf16> + %41 = tt.dot %37, %40, %arg9, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32> + scf.yield %41 : tensor<128x128xf32> + } + // BLACKWELL-COUNT-1: ttng.tmem_load + // BLACKWELL-NOT: ttng.tmem_load + + // HOPPER: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + %22 = arith.addi %arg7, %c132_i32 : i32 + %23 = arith.divsi %22, %9 : i32 + %24 = arith.muli %23, %c8_i32 : i32 + %25 = arith.subi %2, %24 : i32 + %26 = arith.minsi %25, %c8_i32 : i32 + %27 = arith.remsi %22, %26 : i32 + %28 = arith.addi %24, %27 : i32 + %29 = arith.remsi %22, %9 : i32 + %30 = arith.divsi %29, %26 : i32 + %31 = arith.muli %28, %c128_i32 : i32 + %32 = arith.muli %30, %c128_i32 : i32 + %33 = arith.truncf %21 : tensor<128x128xf32> to tensor<128x128xf16> + %34 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> + tt.experimental_descriptor_store %34[%31, %32], %33 : !tt.tensordesc>, tensor<128x128xf16> + scf.yield %22 : i32 + } {tt.flatten} + tt.return +} + From db7ddbc2b1db0de7ba25b197707e61139da01527 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 6 Feb 2025 13:22:16 -0800 Subject: [PATCH 32/32] fmt --- test/TritonGPU/pipeline-loop-nest.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/test/TritonGPU/pipeline-loop-nest.mlir b/test/TritonGPU/pipeline-loop-nest.mlir index aeccc1e1caa7..c4f9dc5f62c1 100644 --- a/test/TritonGPU/pipeline-loop-nest.mlir +++ b/test/TritonGPU/pipeline-loop-nest.mlir @@ -79,4 +79,3 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.p } {tt.flatten} tt.return } -