Do not reorder transpose of dot operand that is used in ops other than dotOp (#5686)

If the result of `convert_layout(trans(x), #dot_operand)` is not used by
`tt.dot`, skip the pattern match that generates `memdesc_trans`. Without
the explicit round trip through shared memory, such cases will be easier
to pipeline for mxfp.
pawelszczerbuk authored Jan 31, 2025
1 parent 0af2f62 commit 7335fcc
Showing 2 changed files with 27 additions and 2 deletions.
7 changes: 5 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -138,8 +138,8 @@ auto cloneSlice(PatternRewriter &rewriter,
 }
 
 // Given
-// convert(trans(src)) #dot_operand ->
-// convert(local_load(trans(alloc(src))))
+// dot(convert(trans(src)) #dot_operand) ->
+// dot(convert(local_load(trans(alloc(src)))))
 // change the encoding of the inner convert to a special, swizzled shared
 // encoding.
 class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
@@ -148,6 +148,9 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
 
   LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                 PatternRewriter &rewriter) const override {
+    if (!cvtOp->hasOneUse() ||
+        !isa<triton::DotOp>(cvtOp->use_begin()->getOwner()))
+      return failure();
     // Match outerCvt(trans(innerCvt(x))).
     auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
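The pattern's doc comment above describes the shape that is still rewritten: a layout conversion of a transpose whose only user is `tt.dot`. For contrast with the new test added below (where an `arith.addf` sits between the convert and the dot, so the rewrite is now skipped), here is a minimal sketch of the still-matched shape. It is adapted from that test and is purely illustrative, not part of this commit's diff; the function name `mmav2_transpose_direct` is hypothetical. On such input the pass remains free to route the operand through shared memory, i.e. the `local_alloc`/`memdesc_trans`/`local_load` form mentioned in the comment.

#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // The convert_layout result has exactly one user, and that user is tt.dot,
  // so the new guard does not bail out and the swizzled-shared rewrite applies.
  tt.func @mmav2_transpose_direct(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma> {
    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
    %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
    tt.return %r : tensor<128x64xf32, #mma>
  }
}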
22 changes: 22 additions & 0 deletions test/TritonGPU/dot-operands.mlir
@@ -305,6 +305,28 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_transpose_indirect
+// CHECK: tt.trans
+// CHECK: ttg.convert_layout
+// CHECK: arith.addf
+// CHECK: tt.dot
+  tt.func @mmav2_transpose_indirect(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %add = arith.addf %cv, %cst : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %add, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
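For reference, the CHECK lines in the test above are exercised by FileCheck through the file's RUN line, which this diff does not show. It is presumably of the following form; treat the exact pass flags as an assumption rather than part of this commit:

// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s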