From 276c46eaa672ed8e28871c6169b362b39eca0c3b Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 22 Jan 2025 17:47:36 +0000 Subject: [PATCH 1/2] [WIP][LAYOUTS] Remove HoistLayoutConversion in favour of backwardsRemat We now support all layouts as LL, and reductions support any layout as input. As such, at least in theory, we should be able to propagate layouts freely, even DotOperands, similar to what we do with other layouts. This PR is a bit tentative. Let's see if anything interesting breaks --- include/triton/Dialect/Triton/IR/TritonOps.td | 3 +- lib/Dialect/Triton/IR/Ops.cpp | 6 + .../Transforms/OptimizeDotOperands.cpp | 142 --------- .../Transforms/RemoveLayoutConversions.cpp | 133 +++++++- test/TritonGPU/combine.mlir | 289 ++++++++++++++++++ test/TritonGPU/dot-operands.mlir | 256 ---------------- 6 files changed, 422 insertions(+), 407 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 8ace528c0436..2537951099a4 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -830,7 +830,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ Elementwise, SameOperandsAndResultEncoding, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "inline assembly applying an elementwise operation to a group of packed elements."; let description = [{ diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 521f922114fb..18f41b9dbcf9 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1037,6 +1037,12 @@ void ElementwiseInlineAsmOp::getEffects( SideEffects::DefaultResource::get()); } +Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + LogicalResult ElementwiseInlineAsmOp::verify() { if (getNumOperands() >= 1) { auto tensorType = dyn_cast(getOperand(0).getType()); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5b268b154241..6362e71ef73d 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -24,36 +24,6 @@ namespace { // Roughly, whether op is elementwise and thus threads don't need // to exchange elements. But some ops are not currently supported even though // they meet that criterion. -bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) { - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(op) && !isPureUnaryInlineAsm(op) && - !isa(op->getDialect())) - return false; - - // Quick handling to fix loading issues when computing the original - // bitwidth is unable to realize that there is a mixed-precision dot - // (hence kWidth = 1) but wants to hoist through the type conversion. - if (isa(op) && dotOpEnc.getKWidth() == 1) - return false; - - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(op)) - return false; - - // Don't hoist through u1 -> fp casts as they aren't supported in - // ElementwiseOpToLLVM::reorderValues(). - if (isa(op)) { - Type opType = getElementTypeOrSelf(op->getOperand(0)); - if (opType.isInteger(1)) - return false; - } - - return true; -} - // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A // is in registers). bool canHoistDotOpEncV3(Operation *op) { @@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern { } }; -// Move convert-to-dot-operand "up" past elementwise ops: -// -// convert(elementwise(x)) #dot_operand -> -// elementwise(convert(x, #dot_operand)). -// -// The goal is to put the convert right next to the originating load. If we can -// accomplish this, then we can save a shmem round-trip: -// -// Before: -// -// - Load from global into shmem using an async copy. -// - Load from shmem into a #blocked layout. -// - Do elementwise ops over #blocked layout. -// - Convert to #dot_operand (round-trip through shmem). -// - Do dot. -// -// After: -// -// - Load from global into shmem using an async copy (same as before). -// - Load from shmem into a #dot_operand layout. -// - Do elementwise ops over #dot_operand layout. -// - Do dot. -// -// This can also be propagated when we have a constant, instead of a load. -// -// Eliminating the shmem round-trip is such a big win, we're willing to do it -// even if this duplicates work because some of the elementwise ops have uses -// that don't flow into the dot. On the other hand, we only want to do this if -// we can in fact reduce shmem round-trips: For example, simply moving a convert -// up above e.g. an `add` now means we have *two* converts. That's worse, -// unless we can continue moving the converts upwards and eventually merge them. -// So we try to check that this will be beneficial before making any changes. -class HoistLayoutConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ConvertLayoutOp cvt, - PatternRewriter &rewriter) const override { - // Only consider conversions to dot operand. - auto cvtTy = cast(cvt.getType()); - auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); - if (!dotOpEnc) - return failure(); - - auto src = cvt.getSrc().getDefiningOp(); - if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) - return failure(); - - auto srcTy = dyn_cast(src->getResult(0).getType()); - if (!srcTy) - return failure(); - - if (!all_of(src->getOperandTypes(), - [](Type ty) { return isa(ty); })) - return failure(); - - if (!canHoistDotOpEncV2(src, dotOpEnc)) - return failure(); - - // Check that the conversion is transitively dependent on a load or a - // constant, and all operations between it and the convert are layout - // preserving. - // - // TODO(jlebar): This is accidentally quadratic; we iterate over the whole - // slice but then at the end we only modify one op! - SetVector slice; - BackwardSliceOptions opt; - opt.omitBlockArguments = true; - getBackwardSlice(cvt.getOperation(), &slice, opt); - - // TODO(jlebar): This is too conservative when there are multiple loads in - // the chain. If one of the loads has a non-layout-preserving op and the - // other does not, then we may or may not accept the chain, depending on - // which load gets hit first by getBackwardSlice. For example: - // cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit. - // cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit. - bool foundInitializer = false; - // Reverse the slice so that we start directly above the convert and check - // that every op allows hoisting until we find a load or a constant. - for (Operation *currOp : llvm::reverse(slice)) { - if (isa(currOp) || isa(currOp)) { - foundInitializer = true; - break; - } - if (!canHoistDotOpEncV2(currOp, dotOpEnc)) - return failure(); - } - if (!foundInitializer) - return failure(); - - SmallVector newOperands; - for (auto operand : src->getOperands()) { - // We checked earlier that all operands are ranked tensors. - auto operandTy = cast(operand.getType()); - Type newCvtTy = RankedTensorType::get( - srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding()); - newOperands.push_back( - rewriter.create(cvt.getLoc(), newCvtTy, operand)); - } - auto newRet = rewriter.clone(*src); - for (int i = 0; i < newOperands.size(); i++) - newRet->setOperand(i, newOperands[i]); - newRet->getResult(0).setType(RankedTensorType::get( - srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding())); - - rewriter.replaceOp(cvt, newRet->getResults()); - return success(); - } -}; - // Rewrite // // dot(alloc(trans() #shared1) -> @@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - if (this->hoistLayoutConversion.getValue()) - patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 903bebd3db12..60b860f8f2b5 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -129,6 +130,9 @@ class LayoutRematerialization { void cleanup(); void backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); + // TODO: Merge the three hoistConvert*(); functions as they are duplicate code + void hoistConvertDotOperand(); + void hoistConvertDotOperand(ConvertLayoutOp convertOp); void hoistConvertOnTopOfExtOrBroadcast(); void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); void hoistConvertIntoConditionals(); @@ -138,6 +142,12 @@ class LayoutRematerialization { void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp); + LogicalResult + getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, + SetVector &slice, + DenseMap &layout, + std::function stopPropagation); + LogicalResult getRematerializableSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, @@ -948,7 +958,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, rewriteSlice(slice, layout, convertOp, mapping); } -LogicalResult LayoutRematerialization::getRematerializableSlice( +LogicalResult LayoutRematerialization::getConvertBackwardSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, std::function stopPropagation) { @@ -975,9 +985,17 @@ LogicalResult LayoutRematerialization::getRematerializableSlice( } return Value(); }; - LogicalResult result = - getConvertBackwardSlice(root, slice, rootEncoding, layout, - stopPropagation, getExistingConversion); + + return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion); +} + +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice, + layout, stopPropagation); if (result.failed() || slice.empty()) return failure(); @@ -1041,8 +1059,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals() { void LayoutRematerialization::backwardRematerialization( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristic to accommodate fused attention + // DotOperand is hoisted by hoistDotOperand RankedTensorType targetType = convertOp.getType(); if (isa(targetType.getEncoding())) return; @@ -1080,12 +1097,108 @@ void LayoutRematerialization::backwardRematerialization( rewriteSlice(slice, layout, convertOp); } +void LayoutRematerialization::hoistConvertDotOperand() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertDotOperand(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertDotOperand( + ConvertLayoutOp convertOp) { + auto targetType = convertOp.getType(); + // The pass is targeted to Nvidia mma/wgmma dot operands + // We move convert #dot_operand next to their loads. This is done + // so that it's then easy to pipeline these loads + // TODO: Perhaps we should do this whenever convertOp is within a loop + + auto dotEnc = dyn_cast(targetType.getEncoding()); + if (!(dotEnc && isa(dotEnc.getParent()))) + return; + + // We hoist over any operation that can be done without data movement between + // threads We do views and elementwise pure ops for now + // UpcastMXFPOp is here temporarily until + // https://github.com/triton-lang/triton/pull/5475 lands + auto noDataMovement = [](Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isa(op); + }; + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads + auto stop = std::not_fn(noDataMovement); + + SetVector slice; + DenseMap layout; + // Set-up the conversion "cache" + LogicalResult result = getConvertBackwardSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop); + if (result.failed()) + return; + + IRMapping mapping; + OpBuilder builder(convertOp.getContext()); + SetVector innerSlice; + for (Value v : slice) { + if (!v.getDefiningOp()) { + LLVM_DEBUG( + { DBGS() << " Block arguments not supported. Got " << v << "\n"; }); + return; + } + auto loadOp = dyn_cast(v.getDefiningOp()); + // We expect the leaves of the slice to be Load or arith::Constant + // This could be generalised if necessary + if (!loadOp) { + auto op = v.getDefiningOp(); + if (isa(op) || noDataMovement(op)) { + innerSlice.insert(v); + continue; + } else { + LLVM_DEBUG({ + DBGS() << " Leaves must be Load or Constant. Got " << v << "\n"; + }); + return; + } + } + builder.setInsertionPointAfter(loadOp); + auto type = dyn_cast(loadOp.getType()); + if (!type) + continue; + auto newType = RankedTensorType::get(type.getShape(), type.getElementType(), + layout[loadOp]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, loadOp.getResult()); + mapping.map(loadOp.getResult(), newConvertOp.getResult()); + } + + if (innerSlice.empty()) { + return; + } + + LLVM_DEBUG({ + DBGS() << " Hoisting " << convertOp << '\n'; + for (Value v : innerSlice) + DBGS() << " " << v << '\n'; + }); + + rewriteSlice(innerSlice, layout, convertOp, mapping); +} + // For convert left we try to hoist them above type extension to reduce the cost // of the convert. void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention + // DotOperand is hoisted by hoistDotOperand RankedTensorType targetType = convertOp.getType(); if (isa(targetType.getEncoding())) return; @@ -1337,6 +1450,10 @@ void hoistConvert(ModuleOp module) { layoutRemat = LayoutRematerialization(funcOp); layoutRemat.hoistConvertIntoConditionals(); layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertDotOperand(); + layoutRemat.cleanup(); }); } } // namespace diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 93e15722645c..4471725a6f69 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -3150,3 +3150,292 @@ tt.func public @reshape_slice_dot_enc(%arg0: tensor<4x16xi32, #blocked>) -> tens } } +#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> +#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> +#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> +#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> +#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> +#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: tt.func @push_elementwise +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] +// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_elementwise( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + + +// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @succeeds_if_arg_is_not_convert_layout( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> + %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + +// CHECK: tt.func @push_inline_asm_op +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_inline_asm_op( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %dotb: tensor<16x16xf16, #Bv2k4>, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @push_convert_both_operands +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @push_convert_both_operands( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_mmav3_to_constant() + // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_mmav3_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + } +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @update_kwidth_slice +// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @update_kwidth_slice( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %cst = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blockedB> + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant() + // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant_above_for() + // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %3 : tensor<32x128xf32, #mma> + } + tt.return %loop#0 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // We currently don't propagate through block arguments on hoistDotOperand + // that being said, https://github.com/triton-lang/triton/pull/5350 + // allowed to lift DotOperand(opIdx=1), which might be alright + + // CHECK: tt.func @do_not_propagate_through_block_arguments() + // CHECK: %[[THROUGH_FOR_OP:.*]] = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[THROUGH_FOR_OP]], + tt.func @do_not_propagate_through_block_arguments() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst, %arg1 = %cst_1) -> (tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma>) : i32 { + %0 = arith.addf %cst, %arg0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg1, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %0, %3 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma> + } + tt.return %loop#1 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // This checks that we propagate dot op layout given the following: + // initializer -> unsupported op -> initializer -> supported ops -> convert, + // where initializers can be constants or loads. + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + %offset = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %broadcast = tt.broadcast %offset : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> + %pa2 = tt.addptr %pa, %broadcast : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + %a = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa1: tensor<16x1x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pa2: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // Confirm that both loads feed directly into a convert_layout. + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + // CHECK: %[[LOAD2:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD2]] + %a1 = tt.load %pa1 : tensor<16x1x!tt.ptr, #blocked> + %a2 = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ab = tt.broadcast %a1 : tensor<16x1xf16, #blocked> -> tensor<16x16xf16, #blocked> + %aa = arith.addf %ab, %a2 : tensor<16x16xf16, #blocked> + %ae = arith.extf %aa : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index c18b2d222ccc..97b87f207660 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,158 +1,9 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s -#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> -#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> -#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> -#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> -#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> -#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> -#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> -#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> - -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: tt.func @push_elementwise -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] -// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> -// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> -tt.func @push_elementwise( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> - %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> - %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - - -// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] -// CHECK: %[[C:.*]] = tt.dot %[[AF16]] -// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> -tt.func @succeeds_if_arg_is_not_convert_layout( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> - %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> - %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> - %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - -// CHECK: tt.func @push_inline_asm_op -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] -// CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] -// CHECK: %[[C:.*]] = tt.dot %[[AF16]] -// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> -tt.func @push_inline_asm_op( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %dotb: tensor<16x16xf16, #Bv2k4>, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> - %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - -} - -// ----- #blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> - -// CHECK: tt.func @push_convert_both_operands -// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> -// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> -tt.func @push_convert_both_operands( - %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> - %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> - %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> -} - -} - -// ----- - -#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> - -// CHECK: tt.func @update_kwidth_slice -// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> -// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> -tt.func @update_kwidth_slice( - %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - %cst = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blockedB> - %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> - %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> - %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> - %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> -} - -} // ----- @@ -305,113 +156,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_to_constant() - // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_mmav3_to_constant() - // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_mmav3_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_to_constant_above_for() - // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { - %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> - scf.yield %3 : tensor<32x128xf32, #mma> - } - tt.return %loop#0 : tensor<32x128xf32, #mma> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @do_not_propagate_through_block_arguments() - // CHECK: %[[THROUGH_FOR_OP:.*]] = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[THROUGH_FOR_OP]], - tt.func @do_not_propagate_through_block_arguments() -> tensor<32x128xf32, #mma> { - %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %loop:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst, %arg1 = %cst_1) -> (tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma>) : i32 { - %0 = arith.addf %cst, %arg0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %3 = tt.dot %2, %1, %arg1, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> - scf.yield %0, %3 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma> - } - tt.return %loop#1 : tensor<32x128xf32, #mma> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( - %pa: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice - // This checks that we propagate dot op layout given the following: - // initializer -> unsupported op -> initializer -> supported ops -> convert, - // where initializers can be constants or loads. - // CHECK: %[[LOAD1:.*]] = tt.load - // CHECK: ttg.convert_layout %[[LOAD1]] - %offset = arith.constant dense<16> : tensor<16x1xi32, #blocked> - %broadcast = tt.broadcast %offset : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> - %pa2 = tt.addptr %pa, %broadcast : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> - %a = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> - %ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> - %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> - } -} - -// ----- - #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> #smem = #ttg.shared_memory #blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> From c4bf9cd74ecf1f78cfa1916b7ca0c62ec84cda8c Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 31 Jan 2025 10:58:13 +0000 Subject: [PATCH 2/2] Make MemDescTransOp accept equivalent LLs --- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 2 +- lib/Dialect/TritonGPU/IR/Ops.cpp | 28 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index ea5ed593dff3..833007947a1d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, TransposeOpInterface, - DeclareOpInterfaceMethods, + InferTypeOpWithLayoutEquivalence, SameOperandsAndResultElementType]> { let summary = "transpose the descriptor"; diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 9222cce69eb6..32f96d574250 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -463,15 +463,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { return {}; } -LogicalResult MemDescTransOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +MemDescTransOp::inferReturnTypes(MLIRContext *context, + std::optional location, + MemDescTransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input - auto argTy = cast(operands[0].getType()); - auto argShape = argTy.getShape(); - auto order = properties.as()->order.asArrayRef(); - SmallVector retShape = applyPermutation(argTy.getShape(), order); + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); auto retEltTy = argTy.getElementType(); Attribute argEncoding = argTy.getEncoding(); @@ -480,17 +482,17 @@ LogicalResult MemDescTransOp::inferReturnTypes( Dialect &dialect = argEncoding.getDialect(); auto inferLayoutInterface = cast(&dialect); if (inferLayoutInterface - ->inferTransOpEncoding(argEncoding, argShape, order, retEncoding) + ->inferTransOpEncoding(argEncoding, shape, order, retEncoding) .failed()) { return failure(); } } - auto memDescTy = cast(argTy); - inferredReturnTypes.push_back(MemDescType::get( - retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), - memDescTy.getMutableMemory())); + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(), + argTy.getMutableMemory())); return success(); } + // LocalAllocOp void LocalAllocOp::getEffects( SmallVectorImpl>