diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 5db9d01a262f..e80642a4dd97 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1,4 +1,5 @@ #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -269,6 +270,28 @@ class ConstantOpAxisInfoVisitor final } }; +class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(ub::PoisonOp op, + ArrayRef *> operands) override { + constexpr int64_t largePowerOf2 = int64_t(1) << 32; + // Poison values are never accessed, thus assume optimistic values. + if (auto shape = dyn_cast(op.getType())) { + unsigned rank = shape.getRank(); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2), + /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), + /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); + } + + return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2}, + /*constancy=*/{1}); + } +}; + template class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { public: @@ -1012,6 +1035,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); visitors.append(); + visitors.append(); visitors.append(); visitors.append, AddSubOpAxisInfoVisitor, diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 06e75ee18d59..773c01e4a2a0 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -3,6 +3,7 @@ #include #include +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -97,16 +98,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( addDynamicallyLegalDialect([&](Operation *op) { - bool hasLegalRegions = true; - for (auto ®ion : op->getRegions()) { - hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); - } - if (hasLegalRegions && typeConverter.isLegal(op)) { - return true; - } - return false; - }); + scf::SCFDialect, ub::UBDialect>( + [&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); // We have requirements for the data layouts addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 54917a23705c..4b0e8b111216 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -859,6 +860,7 @@ class ConvertTritonToTritonGPU // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? 
populateSCFPatterns(typeConverter, patterns); populateCFPatterns(typeConverter, patterns); + patterns.insert>(typeConverter, context); auto inti = llvm::APSInt(32, false); auto i32_ty = IntegerType::get(mod->getContext(), 32); diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index a74d274d386e..a84f9ab77be0 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -1,7 +1,11 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "llvm/Support/Debug.h" #include @@ -16,6 +20,13 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +// This attribute is set by the front-end to control whether fusion is on. +static constexpr llvm::StringLiteral kFlattenAttr = "tt.flatten"; +// This attribute indicates the inner loop length has been speculated. +static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute"; +// This attribute is just used for testing the pass. +static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse"; + namespace { struct FuseNestedLoopsPass : public impl::TritonGPUFuseNestedLoopsBase { @@ -293,27 +304,74 @@ static unsigned getIntTypeWidth(Type type) { } // Generate IR to compute the number of iterations of a loop. -static Value computeNumIters(OpBuilder &b, scf::ForOp loop) { +static Value computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) { // len(range(lb, ub, step)) = ceildiv(ub - lb, step) // This works even if step is negative. - Location loc = loop.getLoc(); Value diff = - b.create(loc, loop.getUpperBound(), loop.getLowerBound()); + b.create(loop.getUpperBound(), loop.getLowerBound()); // Let someone else prove it can be unsigned. - return b.create(loc, diff, loop.getStep()); + return b.create(diff, loop.getStep()); } // Cast an integer or index value to an integer or index `type`, if necessary. -static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, +static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, Type type) { if (value.getType() == type) return value; if (isa(value.getType()) || isa(type)) - return b.create(loc, type, value); + return b.create(type, value); if (cast(value.getType()).getWidth() > cast(type).getWidth()) - return b.create(loc, type, value); - return b.create(loc, type, value); + return b.create(type, value); + return b.create(type, value); +} + +// To model an "undef" value, i.e. a value that is known to never be read on +// live code paths, create a zero-valued constant where possible, otherwise use +// a poison value. PTXAS appears to generate better code with zeros compared to +// poison values. +static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { + Type elTy = getElementTypeOrSelf(type); + if (!elTy.isIntOrIndexOrFloat() || + (!isa(type) && type != elTy)) + return b.create(type); + + TypedAttr attr = isa(elTy) ? 
TypedAttr(b.getFloatAttr(elTy, 0)) + : b.getIntegerAttr(elTy, 0); + if (auto tensor = dyn_cast(type)) + attr = SplatElementsAttr::get(tensor, attr); + return b.create(attr); +} + +static scf::YieldOp getYield(Region &body) { + return cast(body.front().back()); +} + +static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp, + llvm::BitVector indices, + SmallVector replaceWith) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(ifOp); + while (indices.size() < ifOp.getNumResults()) + indices.push_back(false); + + getYield(ifOp.getThenRegion())->eraseOperands(indices); + getYield(ifOp.getElseRegion())->eraseOperands(indices); + + TypeRange newTypes = getYield(ifOp.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, ifOp.getCondition()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + SmallVector replacements; + auto replIt = replaceWith.begin(); + auto resIt = newIf->result_begin(); + for (unsigned i : llvm::seq(ifOp.getNumResults())) + replacements.push_back(indices[i] ? *replIt++ : *resIt++); + assert(ValueRange(replacements).getTypes() == ifOp.getResultTypes()); + ifOp.replaceAllUsesWith(replacements); + ifOp.erase(); + return newIf; } // Given a one level loop nest in the form @@ -342,11 +400,12 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // total_iters = len_i * inner_len // // T = -1 -// i = lbi +// i = lbi - stepi // for _ in range(total_iters): -// T = (T + 1) % inner_len +// T = 0 if T == (inner_len - 1) else T + 1 // // if T == 0: +// i += stepi // prologue0(i) // j0 = lbj0 // if T >= 0 and T < len_j0: @@ -382,7 +441,6 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // // if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - (N + 1): // epilogue(i) -// i += stepi // // This routine can be applied recursively on a loop nest tree, leaf-to-root, to // flatten the loop nest into a single loop. However, this routine only fuses @@ -441,7 +499,9 @@ static Value castIntIfNecessary(OpBuilder &b, Location loc, Value value, // } // // Note: the induction variables will be initialized to their lower bound to -// avoid underflow in lbjk - stepjk. +// avoid underflow in lbjk - stepjk, with the exception of the outer loop +// induction variable, which needs to be incremented inside the prologue to +// avoid a dependency on the epilogue. This helps the scheduler behave. // // Any inputs and outputs of the loop bodies would also need to be handled // similarly: initialized as undef if appropriate and carried through the fused @@ -494,7 +554,8 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { unsigned intTyWidth = getIntTypeWidth(outer.getInductionVar().getType()); // Generate the computations of the fused loop bounds. 
- OpBuilder b(outer); + Location loc = outer.getLoc(); + ImplicitLocOpBuilder b(loc, outer); Value lenOuter = computeNumIters(b, outer); SmallVector lenInners; for (scf::ForOp loop : innerLoops) { @@ -503,12 +564,10 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { intTyWidth = std::max(intTyWidth, getIntTypeWidth(lenInner.getType())); lenInners.push_back(lenInner); } - intTyWidth = std::min(64u, intTyWidth * 2); auto intTy = b.getIntegerType(intTyWidth); - Location loc = outer.getLoc(); auto intTyCst = [&](int64_t v) { - return b.create(loc, IntegerAttr::get(intTy, v)); + return b.create(IntegerAttr::get(intTy, v)); }; // inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N @@ -518,16 +577,16 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector partialInnerSums; partialInnerSums.push_back(innerLen); for (Value lenInner : lenInners) { - lenInner = castIntIfNecessary(b, loc, lenInner, intTy); - lenInner = b.create(loc, intTyCst(1), lenInner); - innerLen = b.create(loc, innerLen, lenInner); + lenInner = castIntIfNecessary(b, lenInner, intTy); + lenInner = b.create(intTyCst(1), lenInner); + innerLen = b.create(innerLen, lenInner); partialInnerSums.push_back(innerLen); } - innerLen = b.create(loc, innerLen, intTyCst(N)); + innerLen = b.create(innerLen, intTyCst(N)); // total_iters = len_i * inner_len - Value totalIters = b.create( - loc, castIntIfNecessary(b, loc, lenOuter, intTy), innerLen); + Value totalIters = + b.create(castIntIfNecessary(b, lenOuter, intTy), innerLen); // The outputs of the prologue, each epilogue, and all inner loop bodies need // to carried through the fused loop. @@ -558,8 +617,9 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { // T = -1 fusedInits.push_back(intTyCst(-1)); - // i = lbi - fusedInits.push_back(outer.getLowerBound()); + // i = lbi - stepi + fusedInits.push_back( + b.create(outer.getLowerBound(), outer.getStep())); unsigned outerArgsStartIdx = fusedInits.size(); llvm::append_range(fusedInits, outer.getInits()); @@ -568,22 +628,22 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { unsigned ivarStartIdx = fusedInits.size(); for (scf::ForOp loop : innerLoops) { fusedInits.push_back( - b.create(loc, loop.getInductionVar().getType())); + createPoisonOrZero(b, loop.getInductionVar().getType())); } unsigned innerOutsStartIdx = fusedInits.size(); for (scf::ForOp loop : innerLoops) { for (Type resultType : loop.getResultTypes()) - fusedInits.push_back(b.create(loc, resultType)); + fusedInits.push_back(createPoisonOrZero(b, resultType)); } unsigned logueOutsStartIdx = fusedInits.size(); - for (Logue &logue : logues) { + for (Logue &logue : llvm::drop_end(logues)) { for (Type outputType : logue.getOutputTypes()) - fusedInits.push_back(b.create(loc, outputType)); + fusedInits.push_back(createPoisonOrZero(b, outputType)); } // for _ in range(total_iters): - auto fused = b.create(loc, intTyCst(0), totalIters, intTyCst(1), - fusedInits); + auto fused = + b.create(intTyCst(0), totalIters, intTyCst(1), fusedInits); // Replace the outer loop args with the args in the fused loop args. 
for (auto [arg, fusedArg] : llvm::zip(outer.getRegionIterArgs(), @@ -592,14 +652,17 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { } b.setInsertionPointToStart(fused.getBody()); - // T = (T + 1) % inner_len + // T = 0 if T == (inner_len - 1) else T + 1 Value T = fused.getRegionIterArg(0); - T = b.create(loc, T, intTyCst(1)); - T = b.create(loc, T, innerLen); + Value nextT = b.create(T, intTyCst(1)); + Value rollover = + b.create(arith::CmpIPredicate::eq, T, + b.create(innerLen, intTyCst(1))); + T = b.create(rollover, intTyCst(0), nextT); - // Replace uses of `i` within the fused loop. - Value i = fused.getRegionIterArg(1); - outer.getInductionVar().replaceAllUsesWith(i); + // `i` is computed inside the first prologue. + Value curI = fused.getRegionIterArg(1); + Value i; assert(partialInnerSums.size() == N + 2); ArrayRef ivars = fused.getRegionIterArgs().slice(ivarStartIdx); @@ -607,15 +670,16 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx; auto logueOutsIt = ValueRange(fused.getRegionIterArgs()).begin() + logueOutsStartIdx; - SmallVector logueIfs, bodyIfs; + SmallVector prologueIfs, bodyIfs; for (unsigned k = 0; k <= N; ++k) { // if T == max(1, len_j0) + ... max(1, len_jk-1) - k + // [[if k == 0]] i += stepi // prologuek(i) // jk = lbjk Value innerStartT = - b.create(loc, partialInnerSums[k], intTyCst(k)); + b.create(partialInnerSums[k], intTyCst(k)); Value prologueCond = - b.create(loc, arith::CmpIPredicate::eq, T, innerStartT); + b.create(arith::CmpIPredicate::eq, T, innerStartT); // The `scf.if` outputs will be `jk` and the outputs of prologuek. We also // have to initialize the inner loop iter args. @@ -625,20 +689,32 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector prologueOutTypes{inner.getInductionVar().getType()}; llvm::append_range(prologueOutTypes, prologue.getOutputTypes()); llvm::append_range(prologueOutTypes, inner.getInits().getTypes()); - auto prologueIf = b.create(loc, prologueOutTypes, prologueCond); - logueIfs.push_back(prologueIf); + if (k == 0) + prologueOutTypes.push_back(curI.getType()); + auto prologueIf = b.create(prologueOutTypes, prologueCond); + prologueIfs.push_back(prologueIf); // Splice prologuek into the `then` region. Block *thenBlock = b.createBlock(&prologueIf.getThenRegion()); prologue.moveBefore(thenBlock, thenBlock->end()); + if (k == 0) { + // Increment `i` and replace its uses inside the prologue. + b.setInsertionPointToStart(thenBlock); + i = b.create(curI, outer.getStep()); + mlir::replaceAllUsesInRegionWith(outer.getInductionVar(), i, + prologueIf.getThenRegion()); + } + // Yield the initialized jk, the prologue outputs, and the initial values of // the inner loop. b.setInsertionPointToEnd(thenBlock); SmallVector thenOuts{inner.getLowerBound()}; llvm::append_range(thenOuts, prologue.getOutputs()); llvm::append_range(thenOuts, inner.getInits()); - b.create(loc, thenOuts); + if (k == 0) + thenOuts.push_back(i); + b.create(thenOuts); // In the `else` region, just yield the last values of jk, the outputs, and // the iter args. 
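An aside, not part of the patch: the counter-based schedule that the `fuseOneLevel` comments above describe is easiest to check by simulating it. The sketch below is a minimal Python model for a nest with a single inner loop (N = 0), assuming unit steps and zero lower bounds; `fused_schedule` and the trace tuples are illustrative names only, not anything in the Triton sources.

```python
def fused_schedule(len_i, len_j):
    """Trace which region runs on each iteration of the fused loop.

    Mirrors the fuseOneLevel comment for a nest with one inner loop
    (N = 0), unit steps, and zero lower bounds.
    """
    inner_len = max(1, len_j)          # inner_len = max(1, len_j0)
    total_iters = len_i * inner_len    # total_iters = len_i * inner_len
    T = -1
    i = -1                             # stands in for i = lbi - stepi
    j = 0
    trace = []
    for _ in range(total_iters):
        T = 0 if T == inner_len - 1 else T + 1
        if T == 0:                     # prologue; i += stepi happens here
            i += 1
            trace.append(("prologue", i))
            j = 0                      # j = lbj
        if 0 <= T < len_j:             # inner body, guarded on T
            trace.append(("body", i, j))
            j += 1                     # j += stepj
        if T == inner_len - 1:         # epilogue
            trace.append(("epilogue", i))
    return trace

# len_i = 3, len_j = 2 interleaves as:
# prologue(0), body(0,0), body(0,1), epilogue(0), prologue(1), body(1,0), ...
print(fused_schedule(3, 2))
```

The same bookkeeping generalizes to N inner loops by comparing T against the running sums max(1, len_j0) + ... + max(1, len_jk-1) - k, which is exactly what the `partialInnerSums` values computed by the pass represent.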
@@ -648,8 +724,10 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector elseOuts{lastJk}; elseOuts.append(logueOutsIt, logueOutsIt + numOuts); elseOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + if (k == 0) + elseOuts.push_back(curI); logueOutsIt += numOuts; - b.create(loc, elseOuts); + b.create(elseOuts); // The results of the `scf.if` become the values of jk and the prologue // outputs for the rest of the fused loop. @@ -662,6 +740,11 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { for (auto [init, iterArg] : llvm::zip(prologueInits, inner.getRegionIterArgs())) iterArg.replaceAllUsesWith(init); + // Replace uses of `i` elsewhere with the prologue result. + if (k == 0) { + i = prologueIf.getResults().back(); + outer.getInductionVar().replaceAllUsesWith(i); + } // if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k // and T < max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k + @@ -670,26 +753,24 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { // jk += stepjk b.setInsertionPointAfter(prologueIf); Value innerEndT = b.create( - loc, innerStartT, castIntIfNecessary(b, loc, lenInners[k], intTy)); + innerStartT, castIntIfNecessary(b, lenInners[k], intTy)); Value ge = - b.create(loc, arith::CmpIPredicate::sge, T, innerStartT); - Value lt = - b.create(loc, arith::CmpIPredicate::slt, T, innerEndT); - Value bodyCond = b.create(loc, ge, lt); + b.create(arith::CmpIPredicate::sge, T, innerStartT); + Value lt = b.create(arith::CmpIPredicate::slt, T, innerEndT); + Value bodyCond = b.create(ge, lt); // The outputs will be the outputs of the inner loop body and the next jk. SmallVector bodyOutTypes{jk.getType()}; llvm::append_range(bodyOutTypes, inner->getResultTypes()); - auto bodyIf = b.create(loc, bodyOutTypes, bodyCond); + auto bodyIf = b.create(bodyOutTypes, bodyCond); bodyIfs.push_back(bodyIf); // Splice bodyk into the `then` region. inner.getBody()->eraseArguments([](Value arg) { return true; }); bodyIf.getThenRegion().takeBody(inner.getBodyRegion()); - auto yield = - cast(bodyIf.getThenRegion().front().getTerminator()); + auto yield = getYield(bodyIf.getThenRegion()); b.setInsertionPoint(yield); - Value nextJk = b.create(loc, jk, inner.getStep()); + Value nextJk = b.create(jk, inner.getStep()); yield->insertOperands(0, nextJk); // The `else` region just forwards the values. @@ -697,69 +778,126 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { SmallVector bodyForwardedOuts{jk}; bodyForwardedOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); bodyOutsIt += inner->getNumResults(); - b.create(loc, bodyForwardedOuts); + b.create(bodyForwardedOuts); // Now we can replace the results of the inner loop with the outputs of the // body if. inner.replaceAllUsesWith( bodyIf.getResults().slice(1, inner.getNumResults())); + // If the inner loop must execute, then its body does not have to be wrapped + // in a conditional. + if (inner->hasAttr(kMustExecuteAttrName)) { + b.setInsertionPoint(bodyIf); + bodyIf.getConditionMutable().assign( + b.create(b.getBoolAttr(true))); + } + // Move the insertion point for the next iteration. b.setInsertionPointAfter(bodyIf); } // if T == len_j0 + len_j1 + ... 
+ len_jN - N - 1: // epilogue(i) - // i += stepi Logue &epilogue = logues.back(); - auto epilogueCond = b.create( - loc, arith::CmpIPredicate::eq, T, - b.create(loc, innerLen, intTyCst(1))); - SmallVector epilogueOutTypes{i.getType()}; - llvm::append_range(epilogueOutTypes, epilogue.getOutputTypes()); - auto epilogueIf = b.create(loc, epilogueOutTypes, epilogueCond); - logueIfs.push_back(epilogueIf); + auto epilogueCond = + b.create(arith::CmpIPredicate::eq, T, + b.create(innerLen, intTyCst(1))); + auto epilogueIf = + b.create(outer.getYieldedValues().getTypes(), epilogueCond); Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion()); epilogue.moveBefore(thenBlock, thenBlock->end()); b.setInsertionPointToEnd(thenBlock); - Value nextI = b.create(loc, i, outer.getStep()); - SmallVector thenOuts{nextI}; - llvm::append_range(thenOuts, epilogue.getOutputs()); - b.create(loc, thenOuts); + b.create(outer.getYieldedValues()); b.createBlock(&epilogueIf.getElseRegion()); - SmallVector elseOuts{i}; - elseOuts.append(logueOutsIt, logueOutsIt + epilogue.getNumOutputs()); - b.create(loc, elseOuts); - epilogue.replaceAllUsesWith( - epilogueIf.getResults().slice(1, epilogue.getNumOutputs()), - epilogueIf.getThenRegion()); + b.create(fused.getRegionIterArgs().slice( + outerArgsStartIdx, outer.getNumRegionIterArgs())); // Finally, create the yield of the fused loop. - SmallVector outerOuts{T, /*i=*/epilogueIf.getResult(0)}; - llvm::append_range(outerOuts, outer.getYieldedValues()); + SmallVector outerOuts{T, i}; + llvm::append_range(outerOuts, epilogueIf.getResults()); for (scf::IfOp bodyIf : bodyIfs) outerOuts.push_back(/*jk=*/bodyIf.getResult(0)); for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) { llvm::append_range(outerOuts, bodyIf.getResults().slice(1, loop.getNumResults())); - loop.erase(); } - for (auto [logueIf, logue] : llvm::zip(logueIfs, logues)) { + for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) { llvm::append_range(outerOuts, logueIf.getResults().slice(1, logue.getNumOutputs())); } b.setInsertionPointToEnd(fused.getBody()); - b.create(loc, outerOuts); + auto outerYield = b.create(outerOuts); outer.replaceAllUsesWith( fused.getResults().slice(outerArgsStartIdx, outer.getNumResults())); - outer.erase(); - // Update the parent's loop to the fused loop. + // Reduce dependencies across inner loops by hoisting the initialization of + // inner loop iter args to the outer loop when possible, and then placing the + // reset of these values in the epilogue. + auto fusedInitsIt = fused.getInitsMutable().begin() + innerOutsStartIdx; + auto fusedArgsIt = fused.getRegionIterArgs().begin() + innerOutsStartIdx; + auto fusedYieldIt = getYield(fused.getBodyRegion())->getOpOperands().begin() + + innerOutsStartIdx; + SmallVector yieldsToUpdate; + SmallVector reset, forwarded; + for (auto [loop, ifOp, bodyIf, prologue] : + llvm::zip(innerLoops, prologueIfs, bodyIfs, logues)) { + unsigned numResults = loop.getNumResults(); + unsigned prologueSkip = 1 + prologue.getNumOutputs(); + + llvm::BitVector removeIndices(prologueSkip + numResults); + SmallVector replaceWith; + for (auto [i, init] : llvm::enumerate(loop.getInits())) { + if (init.getParentRegion() == &fused.getBodyRegion()) + continue; + // Initialize this in the outer loop. 
+ fusedInitsIt[i].assign(init); + replaceWith.push_back(fusedArgsIt[i]); + removeIndices.set(prologueSkip + i); + yieldsToUpdate.push_back(&fusedYieldIt[i]); + forwarded.push_back(bodyIf.getResult(1 + i)); + reset.push_back(init); + } + // Remove the initializers in the corresponding prologue. + eraseIfResults(b, ifOp, removeIndices, replaceWith); + + fusedInitsIt += numResults; + fusedArgsIt += numResults; + fusedYieldIt += numResults; + } + if (!yieldsToUpdate.empty()) { + MutableOperandRange(getYield(epilogueIf.getThenRegion())).append(reset); + MutableOperandRange(getYield(epilogueIf.getElseRegion())).append(forwarded); + b.setInsertionPoint(epilogueIf); + TypeRange newTypes = getYield(epilogueIf.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, epilogueIf.getCondition()); + newIf.getThenRegion().takeBody(epilogueIf.getThenRegion()); + newIf.getElseRegion().takeBody(epilogueIf.getElseRegion()); + epilogueIf.replaceAllUsesWith( + newIf.getResults().take_front(epilogueIf.getNumResults())); + ResultRange newResults = + newIf.getResults().drop_front(epilogueIf.getNumResults()); + for (auto [i, yieldOperand] : llvm::enumerate(yieldsToUpdate)) + yieldOperand->set(newResults[i]); + epilogueIf.erase(); + } + + // Update the parent's loop to the fused loop. Set the new stage count to the + // max stage count of the inner loops. + int numStages = 1; + for (scf::ForOp loop : innerLoops) { + if (auto stageAttr = loop->getAttrOfType(kNumStagesAttrName)) + numStages = std::max(numStages, stageAttr.getInt()); + loop.erase(); + } + outer.erase(); parent->loop = fused; + if (numStages > 1) + fused->setAttr(kNumStagesAttrName, b.getI32IntegerAttr(numStages)); } //===----------------------------------------------------------------------===// @@ -778,14 +916,148 @@ static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) { // Pass Implementation //===----------------------------------------------------------------------===// +// Fuse simple loop nests with a single outer and inner loop, and where the +// inner loop has a `tt.dot` operation. +static bool shouldFuse(const LoopNest &nest) { + if (nest.root->loop->hasAttr(kAlwaysFuseAttrName)) + return true; + + // Only fuse simple loop nests. + return nest.nodes.size() == 2 && nest.root->children.size() == 1 && + nest.root->loop->hasAttr(kFlattenAttr); +} + +// This function identifies a subgraph of cheap ops that can be sunk between two +// regions in the loop nest and moves them, reducing their liveranges. +static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, + llvm::iterator_range prologue, + function_ref inSinkRegion) { + llvm::SetVector sunkOps; + auto canBeSunk = [&](Operation &op) -> std::pair { + if (!isPure(&op) || isa(op)) + return {false, false}; + // An op can be sunk if all its users are inside the inner loop or are + // marked for sinking. + bool isRoot = true; + for (Operation *user : op.getUsers()) { + if (inSinkRegion(user)) + continue; + isRoot = false; + if (sunkOps.contains(user)) + continue; + return {false, false}; + } + return {true, isRoot}; + }; + + // Find the subgraph of operations that can be sunk. 
+ SmallVector roots; + for (Operation &op : llvm::reverse(prologue)) { + auto [canSink, isRoot] = canBeSunk(op); + if (canSink) + sunkOps.insert(&op); + if (isRoot) + roots.push_back(&op); + } + if (sunkOps.empty()) + return; + + sunkOps = topologicalSort(sunkOps); + for (Operation *op : sunkOps) + op->moveBefore(sinkBlock, sinkBefore); +} + +// Sink ops from the prologue into the epilogue when possible. +static void optimizeEpilogueDependencies(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + auto inEpilogue = [&](Operation *op) { + return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false); + }; + Region &limit = outerLoop.getBodyRegion(); + sinkOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), + {outerLoop.getBody()->begin(), innerLoop->getIterator()}, inEpilogue); +} + +// Speculate the length of the inner loop such that the loop is known to execute +// at least once. This way, the inner loop body does not have to be placed +// inside a conditional in the fused loop, which interacts better with the +// pipeliner. +static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + // The inner loop bounds must be outer-loop invariant to speculate from + // outside the loop nest. + Location loc = innerLoop.getLoc(); + llvm::SetVector toHoist; + if (!isOuterLoopInvariant(domInfo, outerLoop, + {innerLoop.getLowerBound(), + innerLoop.getUpperBound(), innerLoop.getStep()}, + toHoist)) + return failure(); + + // Hoist the inner loop bounds computations if necessary. + toHoist = topologicalSort(toHoist); + for (Operation *op : toHoist) + op->moveBefore(outerLoop); + + // Mark the inner loop. + ImplicitLocOpBuilder b(loc, outerLoop); + innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr()); + + // Speculate on whether the length of the inner loop is zero. + Value lenInner = computeNumIters(b, innerLoop); + auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0); + Value innerLoopEmpty = + b.create(arith::CmpIPredicate::eq, lenInner, + b.create(zeroAttr)); + auto ifOp = b.create(outerLoop.getResultTypes(), innerLoopEmpty); + + // In the `then` branch, the inner loop does not execute. Clone the loop nest + // into it and remove the inner loop. + mlir::IRMapping map; + b.createBlock(&ifOp.getThenRegion()); + auto newLoop = cast(b.clone(*outerLoop, map)); + b.create(newLoop.getResults()); + auto newInnerLoop = cast(map.lookup(innerLoop)); + newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits()); + newInnerLoop.erase(); + + // Move the loop nest into the `else` branch. 
+ outerLoop.replaceAllUsesWith(ifOp.getResults()); + Block *block = b.createBlock(&ifOp.getElseRegion()); + outerLoop->remove(); + b.insert(outerLoop); + b.create(outerLoop.getResults()); + + return success(); +} + +static LogicalResult preprocessLoopNest(const LoopNest &nest, + mlir::DominanceInfo &domInfo) { + assert(nest.nodes.size() == 2 && nest.root->children.size() == 1); + + scf::ForOp &outerLoop = nest.root->loop; + scf::ForOp &innerLoop = nest.root->children.front()->loop; + + optimizeEpilogueDependencies(outerLoop, innerLoop, domInfo); + return speculateInnerLoopLength(outerLoop, innerLoop, domInfo); +} + void FuseNestedLoopsPass::runOnOperation() { auto &domInfo = getAnalysis(); for (auto func : getOperation().getOps()) { SmallVector nests; findLoopNests(func, nests); - for (LoopNest &nest : nests) + for (LoopNest &nest : nests) { + if (!shouldFuse(nest)) + continue; + if (!nest.root->loop->hasAttr(kAlwaysFuseAttrName) && + failed(preprocessLoopNest(nest, domInfo))) + continue; flattenLoopNest(nest.root, domInfo); + } } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index e5ff5ba340fd..3ab5366dd101 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -258,7 +258,11 @@ bool LoopPipelinerInternal::verifySchedule() { continue; int64_t producerCycle = it->second; if (consumerCycle < producerCycle - numCylesPerIter * distance) { - consumer->emitError("operation scheduled before its operands"); + InFlightDiagnostic diag = + consumer->emitError("operation scheduled before its operands"); + diag.attachNote(producer->getLoc()) + .append("operand defined here: ") + .appendOp(*producer, OpPrintingFlags().printGenericOpForm()); return false; } } @@ -291,7 +295,19 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { setValueMapping(arg, operand.get(), 0); } + + // If the incoming value to an iter arg from the loop yield is defined outside + // the loop, then that means the iter arg takes that value for all stages + // after the first stage. 
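Also not part of the patch: a schematic Python rendering of the effect of `speculateInnerLoopLength` above, assuming the inner loop bounds are outer-loop invariant. `run_nest` and its callbacks are hypothetical stand-ins for the cloned loop nests, not Triton APIs.

```python
def run_nest(outer_iters, inner_iters, prologue, body, epilogue):
    # Speculate on the inner trip count before entering the nest.
    if inner_iters == 0:
        # `then` branch: a clone of the nest with the (empty) inner loop removed.
        for i in range(outer_iters):
            prologue(i)
            epilogue(i)
    else:
        # `else` branch: the original nest. The inner loop is now known to run
        # at least once (ttg.must-execute), so after fusion its body does not
        # need to be wrapped in a conditional, which pipelines better.
        for i in range(outer_iters):
            prologue(i)
            for j in range(inner_iters):
                body(i, j)
            epilogue(i)

# Example: with an empty inner loop only prologue/epilogue execute.
run_nest(2, 0, lambda i: print("prologue", i),
         lambda i, j: print("body", i, j),
         lambda i: print("epilogue", i))
```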
auto yield = cast(forOp.getBody()->getTerminator()); + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), yield->getOpOperands())) { + if (forOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) + continue; + for (int64_t i = 1; i < maxStage; ++i) + setValueMapping(arg, operand.get(), i); + } + Location loc = forOp.getLoc(); SmallVector predicates(maxStage); for (int64_t i = 0; i < maxStage; i++) { diff --git a/python/src/ir.cc b/python/src/ir.cc index b5411dd4281d..14fec22e5889 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -458,6 +458,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "attribute", py::module_local()); py::class_(m, "integer_attr", py::module_local()); py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); // Ops py::class_(m, "OpState", py::module_local()) @@ -750,6 +751,9 @@ void init_triton_ir(py::module &&m) { self.restoreInsertionPoint(pt); }) // Attr + .def( + "get_unit_attr", + [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) .def("get_bool_attr", [](TritonOpBuilder &self, bool value) { return self.getBuilder().getBoolAttr(value); @@ -1778,6 +1782,13 @@ void init_triton_ir(py::module &&m) { printingFlags); } }) + .def("get_pipeline_str", + [](PassManager &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.printAsTextualPipeline(os); + return str; + }) .def("run", [](PassManager &self, ModuleOp &mod) { // TODO: maybe dump module to file and print error for better // diagnostics diff --git a/python/src/passes.cc b/python/src/passes.cc index b0efc3cb884b..619ece2e3455 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -71,6 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCombineTensorSelectAndIf); ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops); ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling", createTritonGPULoopScheduling, int); ADD_PASS_WRAPPER_0("add_coalesce_async_copy", diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8d8c33c5d456..b38f88eef62e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6569,6 +6569,21 @@ def test_tl_range(device): assert 'cp.async.wait_group 6' in ptx +def test_tl_range_fuse(): + if is_hip(): + pytest.skip("loop fusion is not enabled on AMD") + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, flatten=True): + for j in tl.range(0, ub): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "tt.flatten" in compiled_kernel.asm["ttir"] + assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 + + @triton.jit(noinline=True) def maxnreg_noinline1(X): tl.store(X, 0) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ae7a0c92e22b..02fcdac45c25 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -998,6 +998,7 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None + flatten = None if IteratorClass is language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -1008,6 +1009,7 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor + flatten = iterator.flatten elif IteratorClass is range: # visit 
iterator arguments # note: only `range` iterator is supported now @@ -1082,6 +1084,8 @@ def visit_For(self, node): for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) self.scf_stack.append(node) for_op_body = for_op.get_body(0) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 32c4af9eadb7..1fab285e2815 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2865,9 +2865,12 @@ def kernel(...): :param loop_unroll_factor: Tells the Triton IR level loop unroller how many times to unroll a for loop that this range is used with. Less than 2 for this value implies no unrolling. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, flatten=None): if step is None: self.step = constexpr(1) else: @@ -2880,6 +2883,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact self.end = arg2 self.num_stages = num_stages self.loop_unroll_factor = loop_unroll_factor + self.flatten = flatten def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index d4fc294e11d7..c64ad0ae816d 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -222,14 +222,13 @@ def matmul(a, b): @triton.jit -def _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): - tile_id += NUM_SMS +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (tile_id % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m - return tile_id, pid_m, pid_n + return pid_m, pid_n @triton.autotune( @@ -254,58 +253,45 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being # used in both the prologue and epilogue, so we duplicate the counters as a work-around. 
- tile_id = start_pid - NUM_SMS tile_id_c = start_pid - NUM_SMS - ki = -1 offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) - start_m = pid_m * BLOCK_SIZE_M - start_n = pid_n * BLOCK_SIZE_N - offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) - offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) - offs_am = tl.where(offs_am < M, offs_am, 0) - offs_bn = tl.where(offs_bn < N, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) - offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, b, accumulator) - - if ki == k_tiles - 1: - tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, - NUM_SMS) - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if (c_ptr.dtype.element_ty == tl.float8e4nv): - c = accumulator.to(tl.float8e4nv) - else: - c = accumulator.to(tl.float16) - tl.store(c_ptrs, c, mask=c_mask) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if 
(c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) def matmul_persistent(a, b): @@ -364,59 +350,41 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS tile_id_c = start_pid - NUM_SMS - - ki = -1 - - offs_am = 0 - offs_bn = 0 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) - offs_am = pid_m * BLOCK_SIZE_M - offs_bn = pid_n * BLOCK_SIZE_N - - offs_k = ki * BLOCK_SIZE_K - - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) - accumulator = tl.dot(a, b.T, accumulator) - - if ki == k_tiles - 1: - tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, - NUM_SMS) - - offs_am_c = pid_m * BLOCK_SIZE_M - offs_bn_c = pid_n * BLOCK_SIZE_N - - # Epilogue subtiling is a technique to break our computation and stores into multiple pieces - # By subtiling we can reduce shared memory consumption by the epilogue and instead use that - # memory to increase our stage count. - # In this case we partition the accumulator into 2 BLOCK_SIZE_M x BLOCK_SIZE_N // 2 tensors - if EPILOGUE_SUBTILE: - acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - c0 = acc0.to(dtype) - tl._experimental_descriptor_store(c_desc_ptr, c0, [offs_am_c, offs_bn_c]) - c1 = acc1.to(dtype) - tl._experimental_descriptor_store(c_desc_ptr, c1, [offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) - else: - accumulator = accumulator.to(dtype) - tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am_c, offs_bn_c]) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + # Epilogue subtiling is a technique to break our computation and stores into multiple pieces + # By subtiling we can reduce shared memory consumption by the epilogue and instead use that + # memory to increase our stage count. 
+ # In this case we partition the accumulator into 2 BLOCK_SIZE_M x BLOCK_SIZE_N // 2 tensors + if EPILOGUE_SUBTILE: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c0, [offs_am_c, offs_bn_c]) + c1 = acc1.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c1, [offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + else: + accumulator = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am_c, offs_bn_c]) def matmul_tma_persistent(a, b): @@ -532,54 +500,39 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2], ) - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue tile_id_c = start_pid - NUM_SMS - ki = -1 - - offs_am = 0 - offs_bn = 0 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) - - offs_am = pid_m * BLOCK_SIZE_M - offs_bn = pid_n * BLOCK_SIZE_N - - offs_k = ki * BLOCK_SIZE_K - - a = a_desc.load([offs_am, offs_k]) - b = b_desc.load([offs_bn, offs_k]) - accumulator = tl.dot(a, b.T, accumulator) - - if ki == k_tiles - 1: - tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, - NUM_SMS) - offs_cm = pid_m * BLOCK_SIZE_M - offs_cn = pid_n * BLOCK_SIZE_N - - if EPILOGUE_SUBTILE: - acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - c0 = acc0.to(dtype) - c_desc.store([offs_cm, offs_cn], c0) - c1 = acc1.to(dtype) - c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1) - else: - c = accumulator.to(dtype) - c_desc.store([offs_cm, offs_cn], c) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + offs_cn = pid_n * BLOCK_SIZE_N + + if EPILOGUE_SUBTILE: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c_desc.store([offs_cm, offs_cn], c0) + c1 = acc1.to(dtype) + c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1) + else: + c = accumulator.to(dtype) + c_desc.store([offs_cm, offs_cn], c) def matmul_descriptor_persistent(a, b): diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index bab044f46b5c..62763e4f056c 100644 --- 
a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -143,3 +143,12 @@ tt.func @scatter4_layout(%arg0: !tt.tensordesc>, %arg1: i32, % tt.experimental_descriptor_scatter %arg0[%cst, %arg1], %1 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xf32> tt.return } + +// ----- + +// CHECK-LABEL: @ub_poison +tt.func @ub_poison() { + // CHECK-NEXT: ub.poison : tensor<128x64xf16, #blocked> + %0 = ub.poison : tensor<128x64xf16> + tt.return +} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 25e136514b01..44eb3d47e293 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -160,3 +160,42 @@ module { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + +// CHECK: [[COALESCED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK: @coalesce_poison +tt.func @coalesce_poison(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128xi32, #blocked1> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<128x1xi32, #blocked2> -> tensor<128x1xi32, #blocked3> + %5 = tt.broadcast %4 {axis = 1 : i32} : tensor<128x1xi32, #blocked3> -> tensor<128x64xi32, #blocked3> + %6 = ttg.convert_layout %5 : tensor<128x64xi32, #blocked3> -> tensor<128x64xi32, #blocked> + %7 = tt.addptr %0, %6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + + %8 = ub.poison : tensor<128x64x!tt.ptr, #blocked> + // CHECK: scf.if + %9 = scf.if %arg2 -> (tensor<128x64x!tt.ptr, #blocked>) { + scf.yield %8 : tensor<128x64x!tt.ptr, #blocked> + } else { + scf.yield %7 : tensor<128x64x!tt.ptr, #blocked> + } + // CHECK: [[PTR:%.*]] = ttg.convert_layout %{{.*}} : tensor<128x64x!tt.ptr, #{{.*}}> -> tensor<128x64x!tt.ptr, [[COALESCED_LAYOUT]]> + // CHECK-NEXT: tt.load [[PTR]] + %10 = tt.load %9 : tensor<128x64x!tt.ptr, #blocked> + tt.return +} + +} diff --git a/test/TritonGPU/fuse-nested-loops.mlir b/test/TritonGPU/fuse-nested-loops.mlir index b4a0a5bd8942..bcaf031e468d 100644 --- a/test/TritonGPU/fuse-nested-loops.mlir +++ b/test/TritonGPU/fuse-nested-loops.mlir @@ -17,7 +17,7 @@ tt.func @no_fusion(%lb: index, %ub: index, %step: index) -> index { // CHECK-NEXT: yield scf.yield %1 : index // CHECK-NEXT: } - } + } {"ttg.always-fuse"} // CHECK-NEXT: after.loop "after.loop"() : () -> () tt.return %0 : index @@ -48,30 +48,34 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], 
[[INNER_LEN]] // T = -1 - // i = lbi + // i = lbi - stepi // j = None // for _ in range(total_iters): // - // CHECK: [[UNDEF_I64:%.*]] = ub.poison : i64 + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] // CHECK: scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( - // CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I:%.*]] = [[LBI]], [[J_ARG:%.*]] = [[UNDEF_I64]]) -> (i64, i64, i64) : i64 { + // CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I_ARG:%.*]] = [[I_INIT]], [[J_ARG:%.*]] = %c0_i64) -> (i64, i64, i64) : i64 { scf.for %i = %lbi to %ubi step %stepi : i64 { - // T = (T + 1) % inner_len + // T = 0 if T == (inner_len - 1) else T + 1 // // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 - // CHECK-NEXT: [[T:%.*]] = arith.remsi [[T_PLUS_1]], [[INNER_LEN]] + // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 + // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]] + // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]] // if T == 0: + // i += stepi // prologue(i) // j = lbj // // CHECK: [[START:%.*]] = arith.subi %c0_i64, %c0_i64 : i64 // CHECK-NEXT: [[PROLOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[START]] - // CHECK-NEXT: [[J:%.*]] = scf.if [[PROLOGUE_COND]] -> (i64) { + // CHECK-NEXT: [[JI:%.*]]:2 = scf.if [[PROLOGUE_COND]] -> (i64, i64) { + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] // CHECK-NEXT: "prologue"([[I]]) : (i64) -> () - // CHECK-NEXT: yield [[LBJ]] + // CHECK-NEXT: yield [[LBJ]], [[I]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[J_ARG]] + // CHECK-NEXT: yield [[J_ARG]], [[I_ARG]] // CHECK-NEXT: } "prologue"(%i) : (i64) -> () @@ -84,11 +88,11 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-NEXT: [[LT:%.*]] = arith.cmpi slt, [[T]], [[END]] // CHECK-NEXT: [[COND:%.*]] = arith.andi [[GE]], [[LT]] // CHECK-NEXT: [[J_NEXT:%.*]] = scf.if [[COND]] -> (i64) { - // CHECK-NEXT: "body"([[I]], [[J]]) : (i64, i64) -> () - // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[J]], [[STEPJ]] + // CHECK-NEXT: "body"([[JI]]#1, [[JI]]#0) : (i64, i64) -> () + // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[JI]]#0, [[STEPJ]] // CHECK-NEXT: yield [[J_INCR]] // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[J]] + // CHECK-NEXT: yield [[JI]]#0 // CHECK-NEXT: } scf.for %j = %lbj to %ubj step %stepj : i64 { "body"(%i, %j) : (i64, i64) -> () @@ -98,19 +102,15 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // epilogue(i) // i += stepi // - // CHECK: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 // CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]] - // CHECK-NEXT: [[I_NEXT:%.*]] = scf.if [[EPILOGUE_COND]] -> (i64) { - // CHECK-NEXT: "epilogue"([[I]]) : (i64) -> () - // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]] - // CHECK-NEXT: yield [[I_INCR]] + // CHECK-NEXT: scf.if [[EPILOGUE_COND]] { + // CHECK-NEXT: "epilogue"([[JI]]#1) : (i64) -> () // CHECK-NEXT: } else { - // CHECK-NEXT: yield [[I]] // CHECK-NEXT: } "epilogue"(%i) : (i64) -> () - // CHECK-NEXT: yield [[T]], [[I_NEXT]], [[J_NEXT]] : i64, i64, i64 - } + // CHECK-NEXT: yield [[T]], [[JI]]#1, [[J_NEXT]] : i64, i64, i64 + } {"ttg.always-fuse"} tt.return } @@ -118,32 +118,33 @@ tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub // CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64 // CHECK-SAME: [[INOUT:%.*]]: index tt.func @fuse_one_level_inouts(%lbi: i64, 
%ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64, %inout: index) -> index {
-  // CHECK-DAG: [[UNDEF_I64:%.*]] = ub.poison : i64
-  // CHECK-DAG: [[UNDEF_INDEX:%.*]] = ub.poison : index
-  // CHECK: [[OUTER_OUTS:%.*]]:7 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args(
+  // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]]
+  // CHECK: [[OUTER_OUTS:%.*]]:6 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args(
   // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64,
-  // CHECK-SAME: [[I:%arg[0-9]+]] = [[LBI]]
+  // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]]
   // CHECK-SAME: [[M:%arg[0-9]+]] = [[INOUT]]
-  // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = [[UNDEF_I64]]
-  // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]]
-  // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]]
-  // CHECK-SAME: [[EPILOGUE_OUT_ARG:%arg[0-9]+]] = [[UNDEF_INDEX]]
-  // CHECK-SAME: ) -> (i64, i64, index, i64, index, index, index) : i64 {
+  // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = %c0_i64
+  // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = %c0
+  // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = %c0
+  // CHECK-SAME: ) -> (i64, i64, index, i64, index, index) : i64 {
   %outer_out = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %inout) -> index : i64 {
     // if T == 0:
+    //   i += stepi
     //   prologue(i)
     //   j = lbj
     //
-    // CHECK: [[PROLOGUE_OUTS:%.*]]:3 = scf.if %{{[0-9]+}} -> (i64, index, index) {
+    // CHECK: [[PROLOGUE_OUTS:%.*]]:4 = scf.if %{{[0-9]+}} -> (i64, index, index, i64) {
+    // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]]
     // CHECK-NEXT: [[PROLOGUE_RES:%.*]] = "prologue"([[I]], [[INOUT]], [[M]]) : (i64, index, index) -> index
-    // CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]]
+    // CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]], [[I]]
     // CHECK-NEXT: } else {
-    // CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]]
+    // CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]], [[I_ARG]]
     // CHECK-NEXT: }
     //
     // J := [[PROLOGUE_OUTS]]#0
     // PROLOGUE_OUT := [[PROLOGUE_OUTS]]#1
     // K := [[PROLOGUE_OUTS]]#2
+    // I := [[PROLOGUE_OUTS]]#3
     %prologue_out = "prologue"(%i, %inout, %m) : (i64, index, index) -> index
     // if T >= 0 and T < len_j:
@@ -151,7 +152,7 @@ tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub
     //   j += stepj
     //
     // CHECK: [[BODY_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) {
-    // CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[I]], [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, index, index, index) -> index
+    // CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, index, index, index) -> index
     // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[PROLOGUE_OUTS]]#0, [[STEPJ]]
     // CHECK-NEXT: yield [[J_INCR]], [[BODY_OUT]]
     // CHECK-NEXT: } else {
@@ -166,18 +167,17 @@ tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ub
     //   epilogue(i)
     //   i += stepi
     //
-    // CHECK: [[EPILOGUE_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) {
-    // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[I]], [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index
-    // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]]
-    // CHECK-NEXT: yield [[I_INCR]], [[EPILOGUE_OUT]]
+    // CHECK: [[EPILOGUE_OUTS:%.*]] = scf.if {{.*}} -> (index) {
+    // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index
+    // CHECK-NEXT: yield [[EPILOGUE_OUT]]
     // CHECK-NEXT: } else {
-    // CHECK-NEXT: yield [[I]], [[EPILOGUE_OUT_ARG]]
+    // CHECK-NEXT: yield [[M]]
     // CHECK-NEXT: }
     %epilogue_out = "epilogue"(%i, %prologue_out, %inner_out, %m) : (i64, index, index, index) -> index
-    // CHECK-NEXT: yield %{{.*}}, [[EPILOGUE_OUTS]]#0, [[EPILOGUE_OUTS]]#1, [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1, [[EPILOGUE_OUTS]]#1 : i64, i64, index, i64, index, index, index
+    // CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1 : i64, i64, index, i64, index, index
     scf.yield %epilogue_out : index
-  }
+  } {"ttg.always-fuse"}
   // CHECK: return [[OUTER_OUTS]]#2
   tt.return %outer_out : index
 }
@@ -213,34 +213,35 @@ tt.func @multiple_loops(
   // CHECK: [[INNER_LEN:%.*]] = arith.subi [[PLEN3]], %c2_i64
   // CHECK-NEXT: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]]
-  // CHECK: [[UNDEF_I64:%.*]] = ub.poison : i64
-  // CHECK: [[UNDEF_F32:%.*]] = ub.poison : f32
-  // CHECK: [[OUTS:%.*]]:13 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args(
+  // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]]
+  // CHECK: [[OUTS:%.*]]:12 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args(
   // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64,
-  // CHECK-SAME: [[I:%arg[0-9]+]] = [[LBI]],
+  // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]],
   // CHECK-SAME: [[M:%arg[0-9]+]] = [[M0]],
-  // CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = [[UNDEF_I64]],
-  // CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = [[UNDEF_I64]],
-  // CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = [[UNDEF_I64]],
-  // CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = [[UNDEF_F32]],
-  // CHECK-SAME: [[EPILOGUE_ARG:%arg[0-9]+]] = [[UNDEF_F32]])
+  // CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = %c0_i64,
+  // CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = %c0_i64,
+  // CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = %c0_i64,
+  // CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = %cst,
+  // CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = %cst,
+  // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = %cst,
+  // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = %cst,
+  // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = %cst,
+  // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst)
   %mN = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %m0) -> f32 : i64 {
     // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64
-    // CHECK-NEXT: [[T:%.*]] = arith.remsi [[T_PLUS_1]], [[INNER_LEN]]
+    // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64
+    // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]]
+    // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]]
     // CHECK: [[START0:%.*]] = arith.subi [[PLEN0]], %c0_i64
     // CHECK-NEXT: [[PROLOGUE_COND0:%.*]] = arith.cmpi eq, [[T]], [[START0]]
-    // CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND0]]
+    // CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:4 = scf.if [[PROLOGUE_COND0]]
+    // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]]
     // CHECK-NEXT: [[RES:%.*]] = "prologue0"([[I]], [[M]])
-    // CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]]
+    // CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]], [[I]]
     // CHECK-NEXT: else
-    // CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]]
+    // CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]], [[I_ARG]]
     %k00 = "prologue0"(%i, %m) : (i64, f32) -> f32
     // CHECK: [[END0:%.*]] = arith.addi [[START0]], [[LEN_J0]]
@@ -248,7 +249,7 @@ tt.func @multiple_loops(
     // CHECK-NEXT: [[LT0:%.*]] = arith.cmpi slt, [[T]], [[END0]]
     // CHECK-NEXT: [[BODY_COND0:%.*]] = arith.andi [[GE0]], [[LT0]]
     // CHECK-NEXT: [[BODY0_OUTS:%.*]]:2 = scf.if [[BODY_COND0]]
-    // CHECK-NEXT: [[RES:%.*]] = "body0"([[I]], [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2)
+    // CHECK-NEXT: [[RES:%.*]] = "body0"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2)
     // CHECK-NEXT: [[NEXT_J0:%.*]] = arith.addi [[PROLOGUE0_OUTS]]#0, [[STEPJ0]]
     // CHECK-NEXT: yield [[NEXT_J0]], [[RES]]
     // CHECK-NEXT: else
@@ -261,7 +262,7 @@ tt.func @multiple_loops(
     // CHECK: [[START1:%.*]] = arith.subi [[PLEN1]], %c1_i64
     // CHECK-NEXT: [[PROLOGUE_COND1:%.*]] = arith.cmpi eq, [[T]], [[START1]]
     // CHECK-NEXT: [[PROLOGUE1_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND1]]
-    // CHECK-NEXT: [[RES:%.*]] = "prologue1"([[I]], [[BODY0_OUTS]]#1)
+    // CHECK-NEXT: [[RES:%.*]] = "prologue1"([[PROLOGUE0_OUTS]]#3, [[BODY0_OUTS]]#1)
     // CHECK-NEXT: yield [[LBJ1]], [[RES]], [[RES]]
     // CHECK-NEXT: else
     // CHECK-NEXT: yield [[J1_ARG]], [[PROLOGUE1_ARG]], [[BODY1_ARG]]
@@ -272,7 +273,7 @@ tt.func @multiple_loops(
     // CHECK-NEXT: [[LT1:%.*]] = arith.cmpi slt, [[T]], [[END1]]
     // CHECK-NEXT: [[BODY_COND1:%.*]] = arith.andi [[GE1]], [[LT1]]
     // CHECK-NEXT: [[BODY1_OUTS:%.*]]:2 = scf.if [[BODY_COND1]]
-    // CHECK-NEXT: [[RES:%.*]] = "body1"([[I]], [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2)
+    // CHECK-NEXT: [[RES:%.*]] = "body1"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2)
     // CHECK-NEXT: [[NEXT_J1:%.*]] = arith.addi [[PROLOGUE1_OUTS]]#0, [[STEPJ1]]
     // CHECK-NEXT: yield [[NEXT_J1]], [[RES]]
     // CHECK-NEXT: else
@@ -285,7 +286,7 @@ tt.func @multiple_loops(
     // CHECK: [[START2:%.*]] = arith.subi [[PLEN2]], %c2_i64
     // CHECK-NEXT: [[PROLOGUE_COND2:%.*]] = arith.cmpi eq, [[T]], [[START2]]
     // CHECK-NEXT: [[PROLOGUE2_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND2]]
-    // CHECK-NEXT: [[RES:%.*]] = "prologue2"([[I]], [[BODY1_OUTS]]#1)
+    // CHECK-NEXT: [[RES:%.*]] = "prologue2"([[PROLOGUE0_OUTS]]#3, [[BODY1_OUTS]]#1)
     // CHECK-NEXT: yield [[LBJ2]], [[RES]], [[RES]]
     // CHECK-NEXT: else
     // CHECK-NEXT: yield [[J2_ARG]], [[PROLOGUE2_ARG]], [[BODY2_ARG]]
@@ -296,7 +297,7 @@ tt.func @multiple_loops(
     // CHECK-NEXT: [[LT2:%.*]] = arith.cmpi slt, [[T]], [[END2]]
     // CHECK-NEXT: [[BODY_COND2:%.*]] = arith.andi [[GE2]], [[LT2]]
     // CHECK-NEXT: [[BODY2_OUTS:%.*]]:2 = scf.if [[BODY_COND2]]
-    // CHECK-NEXT: [[RES:%.*]] = "body2"([[I]], [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2)
+    // CHECK-NEXT: [[RES:%.*]] = "body2"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2)
     // CHECK-NEXT: [[NEXT_J2:%.*]] = arith.addi [[PROLOGUE2_OUTS]]#0, [[STEPJ2]]
     // CHECK-NEXT: yield [[NEXT_J2]], [[RES]]
     // CHECK-NEXT: else
@@ -306,21 +307,19 @@ tt.func @multiple_loops(
       scf.yield %res : f32
     }
-    // CHECK: [[END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64
-    // CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[END]]
-    // CHECK-NEXT: [[EPILOGUE_OUTS:%.*]]:2 = scf.if [[EPILOGUE_COND]]
-    // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[I]], [[BODY2_OUTS]]#1)
-    // CHECK-NEXT: [[I_INCR:%.*]] = arith.addi [[I]], [[STEPI]]
-    // CHECK-NEXT: yield [[I_INCR]], [[RES]]
+    // CHECK: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]]
+    // CHECK-NEXT: [[EPILOGUE_OUTS:%.*]] = scf.if [[EPILOGUE_COND]]
+    // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[PROLOGUE0_OUTS]]#3, [[BODY2_OUTS]]#1)
+    // CHECK-NEXT: yield [[RES]]
     // CHECK-NEXT: else
-    // CHECK-NEXT: yield [[I]], [[EPILOGUE_ARG]]
+    // CHECK-NEXT: yield [[M]]
     %out = "epilogue"(%i, %k2N) : (i64, f32) -> f32
-    // CHECK: scf.yield [[T]], [[EPILOGUE_OUTS]]#0, [[EPILOGUE_OUTS]]#1,
+    // CHECK: scf.yield [[T]], [[PROLOGUE0_OUTS]]#3, [[EPILOGUE_OUTS]],
     // CHECK-SAME: [[BODY0_OUTS]]#0, [[BODY1_OUTS]]#0, [[BODY2_OUTS]]#0,
-    // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1, [[EPILOGUE_OUTS]]#1 :
+    // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1 :
     scf.yield %out : f32
-  }
+  } {"ttg.always-fuse"}
   // CHECK: return [[OUTS]]#2
   tt.return %mN : f32
 }
@@ -332,12 +331,12 @@ tt.func @two_loop_nests(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64,
     scf.for %j = %lbj to %ubj step %stepj : i64 {
       "body"(%i, %j) : (i64, i64) -> ()
     }
-  }
+  } {"ttg.always-fuse"}
   scf.for %i = %lbi to %ubi step %stepi : i64 {
     scf.for %j = %lbj to %ubj step %stepj : i64 {
       "body"(%i, %j) : (i64, i64) -> ()
     }
-  }
+  } {"ttg.always-fuse"}
   // CHECK-NOT: scf.for
   // CHECK: tt.return
   tt.return
@@ -360,16 +359,16 @@ tt.func @hoist_loop_bound_computations(%lbi: i64, %ubi: i64, %stepi: i64) {
     %lbj = arith.addi %lbi, %stepi : i64
     %ubj = arith.addi %ubi, %stepi : i64
     %stepj = arith.addi %stepi, %stepi : i64
-    // CHECK: [[J:%.*]] = scf.if
-    // CHECK-NEXT: yield [[LBJ]]
+    // CHECK: [[J:%.*]]:2 = scf.if
+    // CHECK: yield [[LBJ]]
     // CHECK: scf.if
     // CHECK-NEXT: "body"
-    // CHECK-NEXT: arith.addi [[J]], [[STEPJ]]
+    // CHECK-NEXT: arith.addi [[J]]#0, [[STEPJ]]
     scf.for %j = %lbj to %ubj step %stepj : i64 {
       "body"(%i, %j) : (i64, i64) -> ()
    }
-  }
+  } {"ttg.always-fuse"}
   tt.return
 }
@@ -383,25 +382,25 @@ tt.func @cannot_fuse(%lbi: i64, %ubi: i64, %stepi: i64) {
     scf.for %j = %lbj to %ubj step %stepj : i64 {
      "body"(%i, %j) : (i64, i64) -> ()
    }
-  }
+  } {"ttg.always-fuse"}
+  // CHECK-NOT: scf.for
   tt.return
 }
 // CHECK-LABEL: @upcast_i16_to_i32
-// CHECK-SAME: [[LBI:%.*]]: i16, [[UBI:%.*]]: i16, [[STEPI:%.*]]: i16, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16
-tt.func @upcast_i16_to_i32(%lbi: i16, %ubi: i16, %stepi: i16, %lbj: i16, %ubj: i16, %stepj: i16) {
-  // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i16
-  // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i16
+// CHECK-SAME: [[LBI:%.*]]: i32, [[UBI:%.*]]: i32, [[STEPI:%.*]]: i32, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16
+tt.func @upcast_i16_to_i32(%lbi: i32, %ubi: i32, %stepi: i32, %lbj: i16, %ubj: i16, %stepj: i16) {
+  // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i32
+  // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i32
   // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : i16
   // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : i16
   // CHECK: arith.extsi [[LEN_J]] : i16 to i32
-  // CHECK: arith.extsi [[LEN_I]] : i16 to i32
-  scf.for %i = %lbi to %ubi step %stepi : i16 {
+  scf.for %i = %lbi to %ubi step %stepi : i32 {
     scf.for %j = %lbj to %ubj step %stepj : i16 {
-      "body"(%i, %j) : (i16, i16) -> ()
+      "body"(%i, %j) : (i32, i16) -> ()
    }
-  }
+  } {"ttg.always-fuse"}
   tt.return
 }
@@ -419,7 +418,7 @@ tt.func @upcast_index_to_i64(%lbi: index, %ubi: index, %stepi: index, %lbj: inde
     scf.for %j = %lbj to %ubj step %stepj {
      "body"(%i, %j) : (index, index) -> ()
    }
-  }
+  } {"ttg.always-fuse"}
   tt.return
 }
@@ -435,8 +434,109 @@ tt.func @triple_loop_nest(
        "body"(%i, %j, %k) : (i64, i64, i64) -> ()
      }
    }
-  }
+  } {"ttg.always-fuse"}
   // CHECK-NOT: scf.for
   // CHECK: tt.return
   tt.return
 }
+
+// CHECK-LABEL: @preserve_stage_count
+tt.func @preserve_stage_count(%lb: i32, %ub: i32) {
+  %c1_i32 = arith.constant 1 : i32
+
+  // CHECK-COUNT-1: scf.for
+  scf.for %i = %lb to %ub step %c1_i32 : i32 {
+    scf.for %j = %lb to %ub step %c1_i32 : i32 {
+      "body"(%j) : (i32) -> ()
+      scf.yield
+    } {tt.num_stages = 4 : i32}
+    scf.for %j = %lb to %ub step %c1_i32 : i32 {
+      "body"(%j) : (i32) -> ()
+      scf.yield
+    } {tt.num_stages = 6 : i32}
+  } {"ttg.always-fuse"}
+  // CHECK: tt.num_stages = 6 : i32
+  // CHECK-NOT: scf.for
+  tt.return
+}
+
+// CHECK-LABEL: @fuse_attr_speculate
+// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32
+tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
+  %c1_i32 = arith.constant 1 : i32
+
+  // CHECK: [[DIFF:%.*]] = arith.subi [[UB]], [[LB]]
+  // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32
+  // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32
+
+  // CHECK: scf.if [[IS_ZERO]]
+  // CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32
+  // CHECK-NEXT: "prologue"
+  // CHECK-NEXT: }
+
+  // CHECK: else
+  // CHECK-COUNT-1: scf.for
+  // CHECK-NOT: scf.for
+  scf.for %i = %lb to %ub step %c1_i32 : i32 {
+    // CHECK: "prologue"
+    "prologue"(%i) : (i32) -> ()
+    // CHECK: scf.if %true
+    scf.for %j = %lb to %ub step %c1_i32 : i32 {
+      // CHECK-NEXT: "body"
+      "body"(%i, %j) : (i32, i32) -> ()
+      scf.yield
+    }
+  } {tt.flatten}
+  tt.return
+}
+
+// CHECK-LABEL: @speculate_hoist
+// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32
+tt.func @speculate_hoist(%lb: i32, %ub: i32) {
+  %c1_i32 = arith.constant 1 : i32
+
+  // CHECK: [[UBJ:%.*]] = arith.addi [[LB]], [[UB]]
+  // CHECK: [[DIFF:%.*]] = arith.subi [[UBJ]], [[LB]]
+  // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32
+  // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32
+
+  // CHECK: scf.if [[IS_ZERO]]
+  scf.for %i = %lb to %ub step %c1_i32 : i32 {
+    "prologue"(%i) : (i32) -> ()
+    %ubj = arith.addi %lb, %ub : i32
+    scf.for %j = %lb to %ubj step %c1_i32 : i32 {
+      "body"(%i, %j) : (i32, i32) -> ()
+      scf.yield
+    }
+  } {tt.flatten}
+  tt.return
+}
+
+// CHECK-LABEL: @sink_prologue_to_epilogue
+// CHECK-SAME: [[UB:%.*]]: i32
+tt.func @sink_prologue_to_epilogue(%ub: i32) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  // CHECK: else
+  // CHECK: scf.for
+  %0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 {
+    // CHECK: [[PROLOGUE_OUTS:%.*]]:2 = scf.if
+    %0 = arith.addi %i, %ub : i32
+    // CHECK: scf.if %true
+    // CHECK-NEXT: "body"
+    scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 {
+      "body"(%i, %j) : (i32, i32) -> ()
+      scf.yield
+    }
+    // CHECK: scf.if
+    // CHECK-NEXT: [[V0:%.*]] = arith.addi [[PROLOGUE_OUTS]]#1, [[UB]]
+    // CHECK-NEXT: [[V1:%.*]] = arith.addi [[V0]], [[UB]]
+    %1 = arith.addi %0, %ub : i32
+    // CHECK-NEXT: "epilogue"([[V1]])
+    "epilogue"(%1) : (i32) -> ()
+    scf.yield %0 : i32
+  } {tt.flatten}
+
+  tt.return
+}
diff --git a/test/TritonGPU/pipeline-loop-nest.mlir b/test/TritonGPU/pipeline-loop-nest.mlir
new file mode 100644
index 000000000000..c4f9dc5f62c1
--- /dev/null
+++ b/test/TritonGPU/pipeline-loop-nest.mlir
@@ -0,0 +1,81 @@
+// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:100},tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,tritongpu-loop-scheduling,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=BLACKWELL
+// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:90 },tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,canonicalize,tritongpu-combine-tensor-select-and-if,tritongpu-loop-scheduling,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=HOPPER
+
+// BLACKWELL-LABEL: @matmul_kernel_tma_persistent
+// HOPPER-LABEL: @matmul_kernel_tma_persistent
+tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+  %c63_i32 = arith.constant 63 : i32
+  %c127_i32 = arith.constant 127 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c64_i32 = arith.constant 64 : i32
+  %c128_i32 = arith.constant 128 : i32
+  %c8_i32 = arith.constant 8 : i32
+  %c132_i32 = arith.constant 132 : i32
+  %0 = tt.get_program_id x : i32
+  %1 = arith.addi %arg3, %c127_i32 : i32
+  %2 = arith.divsi %1, %c128_i32 : i32
+  %3 = arith.addi %arg4, %c127_i32 : i32
+  %4 = arith.divsi %3, %c128_i32 : i32
+  %5 = arith.addi %arg5, %c63_i32 : i32
+  %6 = arith.divsi %5, %c64_i32 : i32
+  %7 = arith.muli %2, %4 : i32
+  %8 = arith.subi %0, %c132_i32 : i32
+  %9 = arith.muli %4, %c8_i32 : i32
+
+  // BLACKWELL: [[ACC_BUFS:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem,
+  // BLACKWELL: ttg.memdesc_trans
+  // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]]
+  // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %false
+
+  // BLACKWELL: scf.for
+  %10 = scf.for %arg6 = %0 to %7 step %c132_i32 iter_args(%arg7 = %8) -> (i32) : i32 {
+    %11 = arith.divsi %arg6, %9 : i32
+    %12 = arith.muli %11, %c8_i32 : i32
+    %13 = arith.subi %2, %12 : i32
+    %14 = arith.minsi %13, %c8_i32 : i32
+    %15 = arith.remsi %arg6, %14 : i32
+    %16 = arith.addi %12, %15 : i32
+    %17 = arith.remsi %arg6, %9 : i32
+    %18 = arith.divsi %17, %14 : i32
+    %19 = arith.muli %16, %c128_i32 : i32
+    %20 = arith.muli %18, %c128_i32 : i32
+    %21 = scf.for %arg8 = %c0_i32 to %6 step %c1_i32 iter_args(%arg9 = %cst) -> (tensor<128x128xf32>) : i32 {
+      %35 = arith.muli %arg8, %c64_i32 : i32
+      %36 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc>
+      %37 = tt.experimental_descriptor_load %36[%19, %35] : !tt.tensordesc> -> tensor<128x64xf16>
+      %38 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc>
+      %39 = tt.experimental_descriptor_load %38[%20, %35] : !tt.tensordesc> -> tensor<128x64xf16>
+      // BLACKWELL: ttg.memdesc_trans
+      // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]]
+      // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %arg
+
+      // HOPPER: [[RESULT:%.*]] = ttng.warp_group_dot {{.*}} isAsync = true
+      // HOPPER-NEXT: ttng.warp_group_dot_wait [[RESULT]], {{.*}} {pendings = 1 : i32}
+      %40 = tt.trans %39 {order = array} : tensor<128x64xf16> -> tensor<64x128xf16>
+      %41 = tt.dot %37, %40, %arg9, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32>
+      scf.yield %41 : tensor<128x128xf32>
+    }
+    // BLACKWELL-COUNT-1: ttng.tmem_load
+    // BLACKWELL-NOT: ttng.tmem_load
+
+    // HOPPER: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
+    %22 = arith.addi %arg7, %c132_i32 : i32
+    %23 = arith.divsi %22, %9 : i32
+    %24 = arith.muli %23, %c8_i32 : i32
+    %25 = arith.subi %2, %24 : i32
+    %26 = arith.minsi %25, %c8_i32 : i32
+    %27 = arith.remsi %22, %26 : i32
+    %28 = arith.addi %24, %27 : i32
+    %29 = arith.remsi %22, %9 : i32
+    %30 = arith.divsi %29, %26 : i32
+    %31 = arith.muli %28, %c128_i32 : i32
+    %32 = arith.muli %30, %c128_i32 : i32
+    %33 = arith.truncf %21 : tensor<128x128xf32> to tensor<128x128xf16>
+    %34 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc>
+    tt.experimental_descriptor_store %34[%31, %32], %33 : !tt.tensordesc>, tensor<128x128xf16>
+    scf.yield %22 : i32
+  } {tt.flatten}
+  tt.return
+}
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 561b65ec5471..ecfeb3c91607 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -231,7 +231,6 @@ def make_ttir(mod, metadata, opt):
     passes.ttir.add_combine(pm)
     passes.ttir.add_reorder_broadcast(pm)
     passes.common.add_cse(pm)
-    passes.common.add_licm(pm)
     passes.common.add_symbol_dce(pm)
     passes.ttir.add_loop_unroll(pm)
     pm.run(mod)
@@ -260,11 +259,18 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
     passes.common.add_cse(pm)
     if capability // 10 in [8, 9]:
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.common.add_licm(pm)
         passes.ttgpuir.add_optimize_accumulator_init(pm)
+        passes.common.add_canonicalizer(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages)
         passes.ttgpuir.add_pipeline(pm, opt.num_stages)
-    if capability // 10 >= 10:
+    elif capability // 10 >= 10:
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.common.add_licm(pm)
         passes.ttgpuir.add_optimize_accumulator_init(pm)
         passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages)
         passes.ttgpuir.add_pipeline(pm, opt.num_stages)
@@ -272,6 +278,8 @@ def make_ttgir(mod, metadata, opt, capability):
         nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
         nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
         passes.common.add_canonicalizer(pm)
+    else:
+        passes.common.add_licm(pm)
     passes.ttgpuir.add_prefetch(pm)
     passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
     passes.ttgpuir.add_coalesce_async_copy(pm)
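The FileCheck patterns in the fuse-nested-loops tests above all encode the same flattened schedule: a single loop over total_iters iterations whose counter T selects between the prologue, the inner-loop body, and the epilogue of the original nest. The Python sketch below is illustrative only (not part of the patch); prologue/body/epilogue mirror the opaque test ops, the steps are assumed positive, and it covers the single-inner-loop case with len_j >= 1 (the len_j == 0 case is guarded separately, as @fuse_attr_speculate checks).

def ceildiv(a, b):
    # len(range(lb, ub, step)) = ceildiv(ub - lb, step) for positive step.
    return -(-a // b)

def fused_loop_nest(lbi, ubi, stepi, lbj, ubj, stepj, prologue, body, epilogue):
    len_i = ceildiv(ubi - lbi, stepi)
    len_j = ceildiv(ubj - lbj, stepj)
    total_iters = len_i * len_j

    t = -1           # fused-iteration counter, as in the %c-1_i64 iter_arg
    i = lbi - stepi  # advanced inside the prologue branch (the I_INIT/I_ARG pattern)
    j = None
    for _ in range(total_iters):
        # T rolls over to 0 once the previous value reached len_j - 1.
        t = 0 if t == len_j - 1 else t + 1
        if t == 0:              # start of a new outer iteration
            i += stepi
            prologue(i)
            j = lbj
        if 0 <= t < len_j:      # inner-loop body
            body(i, j)
            j += stepj
        if t == len_j - 1:      # end of the outer iteration
            epilogue(i)

For example, fused_loop_nest(0, 3, 1, 0, 2, 1, prologue, body, epilogue) invokes body on the same (i, j) pairs, in the same order, as the original two-level nest, which is the property the CHECK lines above pin down on the generated scf.if structure.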