diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index 5b268b154241..6e93f87a451b 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -138,8 +138,8 @@ auto cloneSlice(PatternRewriter &rewriter,
 }
 
 // Given
-//   convert(trans(src)) #dot_operand ->
-//   convert(local_load(trans(alloc(src))))
+//   dot(convert(trans(src)) #dot_operand) ->
+//   dot(convert(local_load(trans(alloc(src)))))
 // change the encoding of the inner convert to a special, swizzled shared
 // encoding.
 class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
@@ -148,6 +148,9 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
 
   LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                 PatternRewriter &rewriter) const override {
+    if (!cvtOp->hasOneUse() ||
+        !isa<triton::DotOp>(cvtOp->use_begin()->getOwner()))
+      return failure();
     // Match outerCvt(trans(innerCvt(x))).
     auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir
index c18b2d222ccc..ac441555d30c 100644
--- a/test/TritonGPU/dot-operands.mlir
+++ b/test/TritonGPU/dot-operands.mlir
@@ -305,6 +305,28 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_transpose_indirect
+// CHECK: tt.trans
+// CHECK: ttg.convert_layout
+// CHECK: arith.addf
+// CHECK: tt.dot
+  tt.func @mmav2_transpose_indirect(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %add = arith.addf %cv, %cst : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %add, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {