diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 8ace528c0436..2537951099a4 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -830,7 +830,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ Elementwise, SameOperandsAndResultEncoding, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "inline assembly applying an elementwise operation to a group of packed elements."; let description = [{ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index ea5ed593dff3..833007947a1d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, TransposeOpInterface, - DeclareOpInterfaceMethods, + InferTypeOpWithLayoutEquivalence, SameOperandsAndResultElementType]> { let summary = "transpose the descriptor"; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 521f922114fb..18f41b9dbcf9 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1037,6 +1037,12 @@ void ElementwiseInlineAsmOp::getEffects( SideEffects::DefaultResource::get()); } +Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + LogicalResult ElementwiseInlineAsmOp::verify() { if (getNumOperands() >= 1) { auto tensorType = dyn_cast(getOperand(0).getType()); diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 9222cce69eb6..32f96d574250 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -463,15 +463,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { return {}; } -LogicalResult MemDescTransOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +MemDescTransOp::inferReturnTypes(MLIRContext *context, + std::optional location, + MemDescTransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input - auto argTy = cast(operands[0].getType()); - auto argShape = argTy.getShape(); - auto order = properties.as()->order.asArrayRef(); - SmallVector retShape = applyPermutation(argTy.getShape(), order); + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); auto retEltTy = argTy.getElementType(); Attribute argEncoding = argTy.getEncoding(); @@ -480,17 +482,17 @@ LogicalResult MemDescTransOp::inferReturnTypes( Dialect &dialect = argEncoding.getDialect(); auto inferLayoutInterface = cast(&dialect); if (inferLayoutInterface - ->inferTransOpEncoding(argEncoding, argShape, order, retEncoding) + ->inferTransOpEncoding(argEncoding, shape, order, retEncoding) .failed()) { return failure(); } } - auto memDescTy = cast(argTy); - inferredReturnTypes.push_back(MemDescType::get( - retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), - memDescTy.getMutableMemory())); + 
inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(), + argTy.getMutableMemory())); return success(); } + // LocalAllocOp void LocalAllocOp::getEffects( SmallVectorImpl> diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5b268b154241..6362e71ef73d 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -24,36 +24,6 @@ namespace { // Roughly, whether op is elementwise and thus threads don't need // to exchange elements. But some ops are not currently supported even though // they meet that criterion. -bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) { - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(op) && !isPureUnaryInlineAsm(op) && - !isa(op->getDialect())) - return false; - - // Quick handling to fix loading issues when computing the original - // bitwidth is unable to realize that there is a mixed-precision dot - // (hence kWidth = 1) but wants to hoist through the type conversion. - if (isa(op) && dotOpEnc.getKWidth() == 1) - return false; - - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(op)) - return false; - - // Don't hoist through u1 -> fp casts as they aren't supported in - // ElementwiseOpToLLVM::reorderValues(). - if (isa(op)) { - Type opType = getElementTypeOrSelf(op->getOperand(0)); - if (opType.isInteger(1)) - return false; - } - - return true; -} - // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A // is in registers). bool canHoistDotOpEncV3(Operation *op) { @@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern { } }; -// Move convert-to-dot-operand "up" past elementwise ops: -// -// convert(elementwise(x)) #dot_operand -> -// elementwise(convert(x, #dot_operand)). -// -// The goal is to put the convert right next to the originating load. If we can -// accomplish this, then we can save a shmem round-trip: -// -// Before: -// -// - Load from global into shmem using an async copy. -// - Load from shmem into a #blocked layout. -// - Do elementwise ops over #blocked layout. -// - Convert to #dot_operand (round-trip through shmem). -// - Do dot. -// -// After: -// -// - Load from global into shmem using an async copy (same as before). -// - Load from shmem into a #dot_operand layout. -// - Do elementwise ops over #dot_operand layout. -// - Do dot. -// -// This can also be propagated when we have a constant, instead of a load. -// -// Eliminating the shmem round-trip is such a big win, we're willing to do it -// even if this duplicates work because some of the elementwise ops have uses -// that don't flow into the dot. On the other hand, we only want to do this if -// we can in fact reduce shmem round-trips: For example, simply moving a convert -// up above e.g. an `add` now means we have *two* converts. That's worse, -// unless we can continue moving the converts upwards and eventually merge them. -// So we try to check that this will be beneficial before making any changes. 
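
For reference, a minimal sketch of the rewrite described above; this diff deletes the HoistLayoutConversion pattern below and reimplements the hoist as LayoutRematerialization::hoistConvertDotOperand in RemoveLayoutConversions.cpp. The shapes and the #blocked / #dot_op encodings are placeholders rather than the exact attributes used in the tests:

// Before: the convert sits between the elementwise op and the dot, so the
// operand reaches #dot_op via a shmem round-trip.
%a  = tt.load %pa : tensor<16x16x!tt.ptr<f16>, #blocked>
%ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked>
%al = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #dot_op>
%r  = tt.dot %al, %b, %c : ...
//
// After: the convert is hoisted next to the load and the elementwise op is
// rematerialized in the #dot_op layout, eliminating the extra round-trip.
%a  = tt.load %pa : tensor<16x16x!tt.ptr<f16>, #blocked>
%ac = ttg.convert_layout %a : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #dot_op>
%ae = arith.extf %ac : tensor<16x16xf16, #dot_op> to tensor<16x16xf32, #dot_op>
%r  = tt.dot %ae, %b, %c : ...
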
-class HoistLayoutConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ConvertLayoutOp cvt, - PatternRewriter &rewriter) const override { - // Only consider conversions to dot operand. - auto cvtTy = cast(cvt.getType()); - auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); - if (!dotOpEnc) - return failure(); - - auto src = cvt.getSrc().getDefiningOp(); - if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) - return failure(); - - auto srcTy = dyn_cast(src->getResult(0).getType()); - if (!srcTy) - return failure(); - - if (!all_of(src->getOperandTypes(), - [](Type ty) { return isa(ty); })) - return failure(); - - if (!canHoistDotOpEncV2(src, dotOpEnc)) - return failure(); - - // Check that the conversion is transitively dependent on a load or a - // constant, and all operations between it and the convert are layout - // preserving. - // - // TODO(jlebar): This is accidentally quadratic; we iterate over the whole - // slice but then at the end we only modify one op! - SetVector slice; - BackwardSliceOptions opt; - opt.omitBlockArguments = true; - getBackwardSlice(cvt.getOperation(), &slice, opt); - - // TODO(jlebar): This is too conservative when there are multiple loads in - // the chain. If one of the loads has a non-layout-preserving op and the - // other does not, then we may or may not accept the chain, depending on - // which load gets hit first by getBackwardSlice. For example: - // cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit. - // cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit. - bool foundInitializer = false; - // Reverse the slice so that we start directly above the convert and check - // that every op allows hoisting until we find a load or a constant. - for (Operation *currOp : llvm::reverse(slice)) { - if (isa(currOp) || isa(currOp)) { - foundInitializer = true; - break; - } - if (!canHoistDotOpEncV2(currOp, dotOpEnc)) - return failure(); - } - if (!foundInitializer) - return failure(); - - SmallVector newOperands; - for (auto operand : src->getOperands()) { - // We checked earlier that all operands are ranked tensors. 
- auto operandTy = cast(operand.getType()); - Type newCvtTy = RankedTensorType::get( - srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding()); - newOperands.push_back( - rewriter.create(cvt.getLoc(), newCvtTy, operand)); - } - auto newRet = rewriter.clone(*src); - for (int i = 0; i < newOperands.size(); i++) - newRet->setOperand(i, newOperands[i]); - newRet->getResult(0).setType(RankedTensorType::get( - srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding())); - - rewriter.replaceOp(cvt, newRet->getResults()); - return success(); - } -}; - // Rewrite // // dot(alloc(trans() #shared1) -> @@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - if (this->hoistLayoutConversion.getValue()) - patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 903bebd3db12..60b860f8f2b5 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -129,6 +130,9 @@ class LayoutRematerialization { void cleanup(); void backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); + // TODO: Merge the three hoistConvert*(); functions as they are duplicate code + void hoistConvertDotOperand(); + void hoistConvertDotOperand(ConvertLayoutOp convertOp); void hoistConvertOnTopOfExtOrBroadcast(); void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); void hoistConvertIntoConditionals(); @@ -138,6 +142,12 @@ class LayoutRematerialization { void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp); + LogicalResult + getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, + SetVector &slice, + DenseMap &layout, + std::function stopPropagation); + LogicalResult getRematerializableSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, @@ -948,7 +958,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, rewriteSlice(slice, layout, convertOp, mapping); } -LogicalResult LayoutRematerialization::getRematerializableSlice( +LogicalResult LayoutRematerialization::getConvertBackwardSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, std::function stopPropagation) { @@ -975,9 +985,17 @@ LogicalResult LayoutRematerialization::getRematerializableSlice( } return Value(); }; - LogicalResult result = - getConvertBackwardSlice(root, slice, rootEncoding, layout, - stopPropagation, getExistingConversion); + + return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion); +} + +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice, + layout, stopPropagation); if (result.failed() || slice.empty()) return failure(); @@ -1041,8 +1059,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals() { void 
LayoutRematerialization::backwardRematerialization( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristic to accommodate fused attention + // DotOperand is hoisted by hoistDotOperand RankedTensorType targetType = convertOp.getType(); if (isa(targetType.getEncoding())) return; @@ -1080,12 +1097,108 @@ void LayoutRematerialization::backwardRematerialization( rewriteSlice(slice, layout, convertOp); } +void LayoutRematerialization::hoistConvertDotOperand() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertDotOperand(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertDotOperand( + ConvertLayoutOp convertOp) { + auto targetType = convertOp.getType(); + // The pass is targeted to Nvidia mma/wgmma dot operands + // We move convert #dot_operand next to their loads. This is done + // so that it's then easy to pipeline these loads + // TODO: Perhaps we should do this whenever convertOp is within a loop + + auto dotEnc = dyn_cast(targetType.getEncoding()); + if (!(dotEnc && isa(dotEnc.getParent()))) + return; + + // We hoist over any operation that can be done without data movement between + // threads We do views and elementwise pure ops for now + // UpcastMXFPOp is here temporarily until + // https://github.com/triton-lang/triton/pull/5475 lands + auto noDataMovement = [](Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isa(op); + }; + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads + auto stop = std::not_fn(noDataMovement); + + SetVector slice; + DenseMap layout; + // Set-up the conversion "cache" + LogicalResult result = getConvertBackwardSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop); + if (result.failed()) + return; + + IRMapping mapping; + OpBuilder builder(convertOp.getContext()); + SetVector innerSlice; + for (Value v : slice) { + if (!v.getDefiningOp()) { + LLVM_DEBUG( + { DBGS() << " Block arguments not supported. Got " << v << "\n"; }); + return; + } + auto loadOp = dyn_cast(v.getDefiningOp()); + // We expect the leaves of the slice to be Load or arith::Constant + // This could be generalised if necessary + if (!loadOp) { + auto op = v.getDefiningOp(); + if (isa(op) || noDataMovement(op)) { + innerSlice.insert(v); + continue; + } else { + LLVM_DEBUG({ + DBGS() << " Leaves must be Load or Constant. 
Got " << v << "\n"; + }); + return; + } + } + builder.setInsertionPointAfter(loadOp); + auto type = dyn_cast(loadOp.getType()); + if (!type) + continue; + auto newType = RankedTensorType::get(type.getShape(), type.getElementType(), + layout[loadOp]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, loadOp.getResult()); + mapping.map(loadOp.getResult(), newConvertOp.getResult()); + } + + if (innerSlice.empty()) { + return; + } + + LLVM_DEBUG({ + DBGS() << " Hoisting " << convertOp << '\n'; + for (Value v : innerSlice) + DBGS() << " " << v << '\n'; + }); + + rewriteSlice(innerSlice, layout, convertOp, mapping); +} + // For convert left we try to hoist them above type extension to reduce the cost // of the convert. void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention + // DotOperand is hoisted by hoistDotOperand RankedTensorType targetType = convertOp.getType(); if (isa(targetType.getEncoding())) return; @@ -1337,6 +1450,10 @@ void hoistConvert(ModuleOp module) { layoutRemat = LayoutRematerialization(funcOp); layoutRemat.hoistConvertIntoConditionals(); layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertDotOperand(); + layoutRemat.cleanup(); }); } } // namespace diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 93e15722645c..4471725a6f69 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -3150,3 +3150,292 @@ tt.func public @reshape_slice_dot_enc(%arg0: tensor<4x16xi32, #blocked>) -> tens } } +#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> +#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> +#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> +#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> +#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> +#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: tt.func @push_elementwise +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] +// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_elementwise( + %pa: tensor<16x16x!tt.ptr, #ALR> 
{tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + + +// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @succeeds_if_arg_is_not_convert_layout( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> + %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + +// CHECK: tt.func @push_inline_asm_op +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_inline_asm_op( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %dotb: tensor<16x16xf16, #Bv2k4>, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes 
{"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @push_convert_both_operands +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @push_convert_both_operands( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_mmav3_to_constant() + // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_mmav3_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = 
"=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + } +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @update_kwidth_slice +// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @update_kwidth_slice( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %cst = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blockedB> + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %add : tensor<16x16xf32, 
#blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant() + // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant_above_for() + // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %3 : tensor<32x128xf32, #mma> + } + tt.return %loop#0 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // We currently don't propagate through block arguments on hoistDotOperand + // that being said, https://github.com/triton-lang/triton/pull/5350 + // allowed to lift DotOperand(opIdx=1), which might be alright + + // CHECK: tt.func @do_not_propagate_through_block_arguments() + // CHECK: %[[THROUGH_FOR_OP:.*]] = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[THROUGH_FOR_OP]], + tt.func @do_not_propagate_through_block_arguments() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst, %arg1 = %cst_1) -> (tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma>) : i32 { + %0 = arith.addf %cst, %arg0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg1, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %0, %3 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma> + } + tt.return %loop#1 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // This checks that we propagate dot op layout given the following: + // initializer -> unsupported op -> initializer -> supported ops -> convert, + // where initializers can be constants or loads. 
+ // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + %offset = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %broadcast = tt.broadcast %offset : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> + %pa2 = tt.addptr %pa, %broadcast : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + %a = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa1: tensor<16x1x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pa2: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // Confirm that both loads feed directly into a convert_layout. + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + // CHECK: %[[LOAD2:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD2]] + %a1 = tt.load %pa1 : tensor<16x1x!tt.ptr, #blocked> + %a2 = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ab = tt.broadcast %a1 : tensor<16x1xf16, #blocked> -> tensor<16x16xf16, #blocked> + %aa = arith.addf %ab, %a2 : tensor<16x16xf16, #blocked> + %ae = arith.extf %aa : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index c18b2d222ccc..97b87f207660 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,158 +1,9 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s -#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> -#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> -#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> -#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> -#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> -#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> -#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> -#BLR = 
#ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> - -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: tt.func @push_elementwise -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] -// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> -// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> -tt.func @push_elementwise( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> - %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> - %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - - -// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] -// CHECK: %[[C:.*]] = tt.dot %[[AF16]] -// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> -tt.func @succeeds_if_arg_is_not_convert_layout( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> - %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> - %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> - %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - -// CHECK: tt.func @push_inline_asm_op -// CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] -// CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] -// CHECK: %[[C:.*]] = tt.dot %[[AF16]] -// CHECK: tt.return %[[C]] 
: tensor<16x16xf32, #mma> -tt.func @push_inline_asm_op( - %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %dotb: tensor<16x16xf16, #Bv2k4>, - %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ - %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> - %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> - tt.return %newc : tensor<16x16xf32, #Cv2> -} - -} - -// ----- #blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> - -// CHECK: tt.func @push_convert_both_operands -// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> -// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> -tt.func @push_convert_both_operands( - %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> - %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> - %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> -} - -} - -// ----- - -#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - -// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> - -// CHECK: tt.func @update_kwidth_slice -// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> -// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> -tt.func @update_kwidth_slice( - %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - %cst = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blockedB> - %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> - %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> - %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> - %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> - %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 
2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> -} - -} // ----- @@ -305,113 +156,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_to_constant() - // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_mmav3_to_constant() - // CHECK: arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_mmav3_to_constant() -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> { - %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - tt.return %1 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @propagate_dot_op_to_constant_above_for() - // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { - %cst = arith.constant dense<1.000000e+00> : 
tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> - scf.yield %3 : tensor<32x128xf32, #mma> - } - tt.return %loop#0 : tensor<32x128xf32, #mma> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK: tt.func @do_not_propagate_through_block_arguments() - // CHECK: %[[THROUGH_FOR_OP:.*]] = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[THROUGH_FOR_OP]], - tt.func @do_not_propagate_through_block_arguments() -> tensor<32x128xf32, #mma> { - %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %loop:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst, %arg1 = %cst_1) -> (tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma>) : i32 { - %0 = arith.addf %cst, %arg0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %3 = tt.dot %2, %1, %arg1, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> - scf.yield %0, %3 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma> - } - tt.return %loop#1 : 
tensor<32x128xf32, #mma> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { - tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( - %pa: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, - %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, - %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ - // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice - // This checks that we propagate dot op layout given the following: - // initializer -> unsupported op -> initializer -> supported ops -> convert, - // where initializers can be constants or loads. - // CHECK: %[[LOAD1:.*]] = tt.load - // CHECK: ttg.convert_layout %[[LOAD1]] - %offset = arith.constant dense<16> : tensor<16x1xi32, #blocked> - %broadcast = tt.broadcast %offset : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> - %pa2 = tt.addptr %pa, %broadcast : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> - %a = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> - %ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> - %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - tt.return %r : tensor<16x16xf32, #mma> - } -} - -// ----- - #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> #smem = #ttg.shared_memory #blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>