From 368f041837caee254cbb822b669640af3e1f637c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 04:33:20 +0900 Subject: [PATCH 01/17] add empty interface --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 008103d81c21..b5ea4b8b0f29 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -149,8 +149,26 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// +def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ + }]; + + let methods = [ + ]; +} + +class SharedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr<"SharedEncoding", "shared_encoding", !listconcat([SharedEncodingTrait], traits), dialect> { + + let description = [{ + }]; +} -def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { +// def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { +def SharedEncodingAttr : SharedEncoding { let mnemonic = "shared"; let description = [{ From 60bfa1b12eb26a711a64115e40505965b221bcb7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 06:39:26 +0900 Subject: [PATCH 02/17] populate interface methods --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 40 ++++++++++++++++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 13 +++++- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td 
b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index b5ea4b8b0f29..166a81aa1f33 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -156,6 +156,24 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { }]; let methods = [ + InterfaceMethod<"", + "unsigned", + "getVec">, + InterfaceMethod<"", + "unsigned", + "getPerPhase">, + InterfaceMethod<"", + "unsigned", + "getMaxPhase">, + InterfaceMethod<"", + "ArrayRef", + "getOrder">, + // InterfaceMethod<"", + // "CTALayoutAttr", + // "getCTALayout">, + InterfaceMethod<"", + "bool", + "getHasLeadingOffset"> ]; } @@ -165,6 +183,15 @@ class SharedEncoding traits = [], let description = [{ }]; + + code extraSharedEncodingDeclaration = extraBaseClassDeclaration # [{ + unsigned getVec() const; + unsigned getPerPhase() const; + unsigned getMaxPhase() const; + ArrayRef getOrder() const; + // CTALayoutAttr getCTALayout(); + bool getHasLeadingOffset() const; + }]; } // def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { @@ -257,12 +284,12 @@ compared to 1*64 when the hasLeadingOffset is false. // order: the fastest-changing axis first let parameters = ( ins - "unsigned":$vec, - "unsigned":$perPhase, - "unsigned":$maxPhase, - ArrayRefParameter<"unsigned">:$order, + "unsigned":$vec__, + "unsigned":$perPhase__, + "unsigned":$maxPhase__, + ArrayRefParameter<"unsigned">:$order__, "CTALayoutAttr":$CTALayout, - "bool":$hasLeadingOffset + "bool":$hasLeadingOffset__ ); let builders = [ @@ -450,7 +477,8 @@ compared to 1*64 when the hasLeadingOffset is false. }]> ]; - let extraClassDeclaration = extraBaseClassDeclaration # [{ + let extraClassDeclaration = extraSharedEncodingDeclaration # [{ + // TODO: Should this be part of the interface? 
int32_t getAlignment() const; }]; let hasCustomAssemblyFormat = 1; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 0ada09159a4f..9b1d082bb7fe 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" // Include TableGen'erated code @@ -885,6 +886,16 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, llvm_unreachable("getElemsPerThread is not supported for shared layout"); return 0; } + +unsigned SharedEncodingAttr::getVec() const { return getVec__(); } +unsigned SharedEncodingAttr::getPerPhase() const { return getPerPhase__(); } +unsigned SharedEncodingAttr::getMaxPhase() const { return getMaxPhase__(); } +ArrayRef SharedEncodingAttr::getOrder() const { return getOrder__(); } + +bool SharedEncodingAttr::getHasLeadingOffset() const { + return getHasLeadingOffset__(); +} + int32_t SharedEncodingAttr::getAlignment() const { if (getHasLeadingOffset()) return 128 * getMaxPhase(); @@ -1872,7 +1883,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { << "vec = " << getVec() // << ", perPhase = " << getPerPhase() << ", maxPhase = " << getMaxPhase() // - << ", order = [" << getOrder() << "]"; + << ", order = [" << ArrayRef(getOrder()) << "]"; maybePrintCTALayout(getContext(), printer, getCTALayout(), /*rank=*/getOrder().size()); printer << ", hasLeadingOffset = " << getHasLeadingOffset() << "}>"; From 28556f5300363240d98e5c57c85413df0f02bc3b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 08:51:45 +0900 Subject: [PATCH 03/17] removed hardcoded SharedEncodingAttr::getCTALayout() --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 829b13f719fc..b72635350061 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -127,7 +127,7 @@ LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, if (rank == 1) { return combineCtaCgaWithShape( LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), - shared.getCTALayout(), shape); + getCTALayout(shared), shape); } auto outDimNames = standardOutDimNames(ctx, rank); @@ -163,7 +163,7 @@ LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); } - return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(shared), shape); } } // namespace @@ -180,7 +180,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, // TODO: Not sure if this is correct. return combineCtaCgaWithShape( LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), - shared.getCTALayout(), shape); + getCTALayout(shared), shape); } int tileWidthBytes; @@ -244,7 +244,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); } - return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); + return combineCtaCgaWithShape(tileLayout, getCTALayout(shared), shape); } /// Function to generate lane and warp layout for dot operands. 
From 3147ef696e9a11ac3265d71c779dd85570682cab Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 09:08:26 +0900 Subject: [PATCH 04/17] add cta layout stuff into interface --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 17 +++++++-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 38 +++++++++++++------ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 166a81aa1f33..1b053bd66d0e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -168,9 +168,15 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { InterfaceMethod<"", "ArrayRef", "getOrder">, - // InterfaceMethod<"", - // "CTALayoutAttr", - // "getCTALayout">, + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + InterfaceMethod<"Get the order of the CTAs per CGA. 
The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, InterfaceMethod<"", "bool", "getHasLeadingOffset"> @@ -188,8 +194,11 @@ class SharedEncoding traits = [], unsigned getVec() const; unsigned getPerPhase() const; unsigned getMaxPhase() const; + // TODO: make ArrayRef vs SmallVector consistent ArrayRef getOrder() const; - // CTALayoutAttr getCTALayout(); + SmallVector getCTAsPerCGA() const; + SmallVector getCTASplitNum() const; + SmallVector getCTAOrder() const; bool getHasLeadingOffset() const; }]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9b1d082bb7fe..35f6cb41796b 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -11,6 +11,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -20,6 +21,7 @@ #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" // Include TableGen'erated code #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" @@ -307,8 +309,10 @@ CTALayoutAttr getCTALayout(Attribute layout) { return CTALayoutAttr::get( layout.getContext(), getCTAsPerCGA(distributedLayout), getCTASplitNum(distributedLayout), getCTAOrder(distributedLayout)); - } else if (auto sharedLayout = mlir::dyn_cast(layout)) - return sharedLayout.getCTALayout(); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) + return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(sharedLayout), + getCTASplitNum(sharedLayout), + getCTAOrder(sharedLayout)); else llvm::report_fatal_error("Unimplemented usage of getCTALayout"); return {}; @@ -318,8 
+322,8 @@ SmallVector getCTAsPerCGA(Attribute layout) { ArrayRef ref; if (auto distributedLayout = mlir::dyn_cast(layout)) return distributedLayout.getCTAsPerCGA(); - else if (auto sharedLayout = mlir::dyn_cast(layout)) - ref = sharedLayout.getCTALayout().getCTAsPerCGA(); + else if (auto sharedLayout = mlir::dyn_cast(layout)) + ref = sharedLayout.getCTAsPerCGA(); else llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); return SmallVector(ref.begin(), ref.end()); @@ -330,9 +334,8 @@ SmallVector getCTASplitNum(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { return distributedLayout.getCTASplitNum(); - } else if (auto sharedLayout = mlir::dyn_cast(layout)) { - res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), - sharedLayout.getCTALayout().getCTASplitNum().end()); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + return sharedLayout.getCTASplitNum(); } else if (auto tmemLayout = mlir::dyn_cast( layout)) { @@ -351,16 +354,15 @@ SmallVector getCTASplitNum(Attribute layout) { } SmallVector getCTAOrder(Attribute layout) { - SmallVector res; if (auto distributedLayout = mlir::dyn_cast(layout)) { - res = distributedLayout.getCTAOrder(); - } else if (auto sharedLayout = mlir::dyn_cast(layout)) { - res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); + return distributedLayout.getCTAOrder(); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + return sharedLayout.getCTAOrder(); } else { llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); } - return res; + llvm_unreachable("Unhandled case"); } SmallVector getShapePerCTA(ArrayRef CTASplitNum, @@ -892,6 +894,18 @@ unsigned SharedEncodingAttr::getPerPhase() const { return getPerPhase__(); } unsigned SharedEncodingAttr::getMaxPhase() const { return getMaxPhase__(); } ArrayRef SharedEncodingAttr::getOrder() const { return getOrder__(); } +SmallVector SharedEncodingAttr::getCTAsPerCGA() const { + return 
SmallVector(getCTALayout().getCTAsPerCGA()); +} + +SmallVector SharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} + +SmallVector SharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + bool SharedEncodingAttr::getHasLeadingOffset() const { return getHasLeadingOffset__(); } From b2033e0bb6b27dc020a19d1e51941d9ef2421bd3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 09:12:57 +0900 Subject: [PATCH 05/17] Return SmallVector from getOrder --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 5 ++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 ++++- .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 1b053bd66d0e..a096abe4c6e3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -166,7 +166,7 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { "unsigned", "getMaxPhase">, InterfaceMethod<"", - "ArrayRef", + "SmallVector", "getOrder">, InterfaceMethod<"Get the shape of the CTAs per CGA.", "SmallVector", @@ -194,8 +194,7 @@ class SharedEncoding traits = [], unsigned getVec() const; unsigned getPerPhase() const; unsigned getMaxPhase() const; - // TODO: make ArrayRef vs SmallVector consistent - ArrayRef getOrder() const; + SmallVector getOrder() const; SmallVector getCTAsPerCGA() const; SmallVector getCTASplitNum() const; SmallVector getCTAOrder() const; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 35f6cb41796b..3ba9684afe2d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -892,7 +892,10 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, unsigned 
SharedEncodingAttr::getVec() const { return getVec__(); } unsigned SharedEncodingAttr::getPerPhase() const { return getPerPhase__(); } unsigned SharedEncodingAttr::getMaxPhase() const { return getMaxPhase__(); } -ArrayRef SharedEncodingAttr::getOrder() const { return getOrder__(); } + +SmallVector SharedEncodingAttr::getOrder() const { + return SmallVector(getOrder__()); +} SmallVector SharedEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h index e0c852ebd040..471b513f4f78 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h @@ -58,7 +58,7 @@ class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader { bool trans; Value elemsPerSwizzlingRowVal; SmallVector instrShape; - ArrayRef ord; + SmallVector ord; int elemsPerSwizzlingRow; int64_t elemBits; Value descriptor; From 47fa1ff3125da07707d699228a2d6661a61b8f44 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 13:10:18 +0900 Subject: [PATCH 06/17] replace isa --- lib/Analysis/Allocation.cpp | 4 ++-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 4 ++-- lib/Dialect/TritonGPU/IR/Ops.cpp | 6 +++--- lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 2 +- .../Transforms/OptimizeDotOperands.cpp | 4 ++-- .../Pipeliner/MatmulLoopPipeline.cpp | 2 +- .../Transforms/ReduceDataDuplication.cpp | 2 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- .../DotOpToLLVM/MMAv5.cpp | 4 ++-- .../DotOpToLLVM/WGMMA.cpp | 18 +++++++++--------- .../TensorMemoryToLLVM.cpp | 2 +- 11 files changed, 25 insertions(+), 25 deletions(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 9e45ebe6aa00..2663797b7cd3 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -165,8 +165,8 @@ unsigned 
defaultAllocationAnalysisScratchSizeFn(Operation *op) { auto dstTy = cvtLayout.getType(); auto srcEncoding = srcTy.getEncoding(); auto dstEncoding = dstTy.getEncoding(); - if (mlir::isa(srcEncoding) || - mlir::isa(dstEncoding)) { + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { // Conversions from/to shared memory do not need scratch memory. return 0; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 3ba9684afe2d..e143444d5721 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2587,7 +2587,7 @@ struct TritonGPUInferLayoutInterface auto mmaRetEncoding = mlir::dyn_cast(retEncoding); if (mmaRetEncoding && mmaRetEncoding.isHopper()) { auto dotOpEnc = mlir::dyn_cast(operandEncoding); - if (!mlir::isa(operandEncoding) && + if (!mlir::isa(operandEncoding) && !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && mlir::isa(dotOpEnc.getParent()))) { return emitOptionalError( @@ -3023,7 +3023,7 @@ struct TritonGPUVerifyTensorLayoutInterface LogicalResult verifyTensorLayout( Attribute layout, RankedTensorType rankedTy, ModuleOp module, function_ref makeErr) const override { - if (isa(layout)) + if (isa(layout)) return makeErr() << "Shared layout is not allowed on tensor type."; // TODO(jlebar): Currently this only checks blocked layouts, but other // layouts also have invariants! 
diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 9222cce69eb6..5e3566f47c83 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -218,7 +218,7 @@ struct CanonicalizeConvertFromConvert return failure(); // for hopper MMAv3 - if (mlir::isa(dstType.getEncoding()) && + if (mlir::isa(dstType.getEncoding()) && mlir::isa(srcType.getEncoding()) && llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { return dot->hasTrait(); @@ -606,11 +606,11 @@ LogicalResult MemDescSubviewOp::verify() { return emitError("src and result must both have or not have an encoding"); } - if (!isa(srcEnc) && + if (!isa(srcEnc) && !isa(srcEnc)) { return emitError("src encoding must be SharedEncodingAttr"); } - if (!isa(dstEnc) && + if (!isa(dstEnc) && !isa(srcEnc)) { return emitError("result encoding must be SharedEncodingAttr"); } diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index b3814329ae72..db0ef0dea193 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -120,7 +120,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { for (auto operand : op->getOperands()) { auto tensorType = dyn_cast(operand.getType()); if (tensorType && - !isa(tensorType.getEncoding())) { + !isa(tensorType.getEncoding())) { Type newType = getNewType(tensorType, encoding); newArgs.push_back(builder.create( op->getLoc(), newType, operand)); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5b268b154241..4f6a59093eb3 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -385,7 +385,7 @@ struct MMAV3UseRegOperand return cast(v.getType()).getEncoding(); }; - if (!isa(getEncoding(dotOp.getOperand(0)))) + if (!isa(getEncoding(dotOp.getOperand(0)))) return 
failure(); auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); auto dstEnc = @@ -444,7 +444,7 @@ struct MMAV3HoistLayoutConversion return cast(v.getType()).getEncoding(); }; - if (!isa(getEncoding(dotOp.getOperand(0)))) + if (!isa(getEncoding(dotOp.getOperand(0)))) return rewriter.notifyMatchFailure( dotOp, "requires Shared encoding for operand A"); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 47ed1b232423..3ff045937b05 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1429,7 +1429,7 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, // Rule 1: All shmem operands are multi-buffered. auto checkOperand = [&](Value operand) { - if (!isa( + if (!isa( cast(operand.getType()).getEncoding())) { // Rule 1a: Register operands must not be modified within the loop. // First, check for chained WGMMA as an exception. 
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index af756c6d83e9..24261ca15783 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -36,7 +36,7 @@ class TritonGPUReduceDataDuplicationPass auto srcType = cast(cvtOp.getSrc().getType()); auto dstType = cast(cvtOp.getType()); auto srcEncoding = srcType.getEncoding(); - if (isa(srcEncoding)) + if (isa(srcEncoding)) return; auto dstDotOp = dyn_cast(dstType.getEncoding()); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 5a13a64535f8..2b4aa98ad647 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -284,7 +284,7 @@ std::string GraphLayoutMarker::getColor(const Type &type) const { return "lightslateblue"; else if (isa(layout)) return "orange"; - else if (isa(layout)) + else if (isa(layout)) return "orangered"; else { llvm::report_fatal_error("Unrecognized layout"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index eff746d2a0fa..c57c3bf1824a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -435,10 +435,10 @@ struct TCGen5MMAOpConversion ConversionPatternRewriter &rewriter) const override { auto AEnc = op.getA().getType().getEncoding(); auto BEnc = op.getB().getType().getEncoding(); - assert(mlir::isa(AEnc) || + assert(mlir::isa(AEnc) || mlir::isa(AEnc) && "Operand A should use Shared or Tensor memory layout."); - assert(mlir::isa(BEnc) && + assert(mlir::isa(BEnc) && "Operand B should use Shared layout."); assert(op.getBarrier() && "tensorcore op should have a barrier at this point."); diff --git 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 2c20942e155f..9bf749171557 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -34,7 +34,7 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingTrait; triton::nvgpu::WGMMAEltType getMmaRetType(Value d) { auto dTy = cast(d.getType()).getElementType(); @@ -68,7 +68,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { } } -int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout, +int64_t getSwizzlingFromLayout(const SharedEncodingTrait &layout, uint32_t widthInByte) { int perPhase = layout.getPerPhase(); int maxPhase = layout.getMaxPhase(); @@ -127,7 +127,7 @@ mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( instrShape(instrShape), elemBits(elementBitwidth) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto ty = cast(tensor.getType()); - auto sharedLayout = cast(ty.getEncoding()); + auto sharedLayout = cast(ty.getEncoding()); ord = sharedLayout.getOrder(); const int perPhase = sharedLayout.getPerPhase(); const int maxPhase = sharedLayout.getMaxPhase(); @@ -179,7 +179,7 @@ DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, Value tensor, Value smemObjBase, Value thread) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto aTy = cast(tensor.getType()); - auto aSharedLayout = dyn_cast(aTy.getEncoding()); + auto aSharedLayout = dyn_cast(aTy.getEncoding()); assert(aSharedLayout && "only support load dot operand from shared."); auto instrShape = mmaEncoding.getInstrShape(); auto wpt = mmaEncoding.getWarpsPerCTA(); @@ -215,7 +215,7 @@ 
DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, Value base, Value thread) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto bTy = cast(tensor.getType()); - auto bSharedLayout = cast(bTy.getEncoding()); + auto bSharedLayout = cast(bTy.getEncoding()); assert(bSharedLayout && "only support load B from shared."); auto instrShape = mmaEncoding.getInstrShape(); auto wpt = mmaEncoding.getWarpsPerCTA(); @@ -365,8 +365,8 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); - auto aSharedLayout = dyn_cast(aTensorTy.getEncoding()); - auto bSharedLayout = cast(bTensorTy.getEncoding()); + auto aSharedLayout = dyn_cast(aTensorTy.getEncoding()); + auto bSharedLayout = cast(bTensorTy.getEncoding()); auto mmaEncoding = cast(dTensorTy.getEncoding()); auto bOrd = bSharedLayout.getOrder(); bool transA = false; @@ -515,9 +515,9 @@ LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, ConversionPatternRewriter &rewriter, Value thread) { auto AEnc = op.getA().getType().getEncoding(); auto BEnc = op.getB().getType().getEncoding(); - assert(mlir::isa(AEnc) || + assert(mlir::isa(AEnc) || mlir::isa(AEnc)); - assert(mlir::isa(BEnc) && + assert(mlir::isa(BEnc) && "Operand B should use Shared layout."); return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp index ec59b9681521..d34d643e880e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -481,7 +481,7 @@ struct TensorMemoryCopyOpConversion auto b = TritonLLVMOpBuilder(loc, rewriter); auto srcTy = cast(op.getSrc().getType()); 
assert(isa(srcTy.getMemorySpace())); - assert(isa(srcTy.getEncoding())); + assert(isa(srcTy.getEncoding())); auto sharedEnc = cast(srcTy.getEncoding()); assert( From 825288b93c75ca1309a6bc3d5a039392314a7176 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 13:25:40 +0900 Subject: [PATCH 07/17] replacing SharedEncodingAttr -> SharedEncodingTrait --- bin/triton-tensor-layout.cpp | 2 +- .../triton/Conversion/TritonGPUToLLVM/Utility.h | 2 +- .../Dialect/TritonGPU/IR/LinearLayoutConversions.h | 4 ++-- .../triton/Dialect/TritonGPU/Transforms/Utility.h | 2 +- lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp | 2 -- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 2 +- lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp | 1 - .../TritonGPU/IR/LinearLayoutConversions.cpp | 6 +++--- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 14 ++++++++++---- .../Transforms/Pipeliner/MatmulLoopPipeline.cpp | 4 ++-- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- .../TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp | 10 +++++----- 12 files changed, 27 insertions(+), 24 deletions(-) diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 033f5e688029..90368de76830 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -84,7 +84,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { // toLinearLayout interface. 
mlir::Attribute layout = tensorType.getEncoding(); if (isa(layout)) { + mlir::triton::gpu::SharedEncodingTrait>(layout)) { os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); return success(); } diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index f06021f6a242..e79473b767da 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1288,7 +1288,7 @@ inline Value packLLVector(Location loc, ValueRange vals, inline bool isSimpleSharedMemoryAccess(ArrayRef shape, ArrayRef allocShape, - triton::gpu::SharedEncodingAttr sharedEnc) { + triton::gpu::SharedEncodingTrait sharedEnc) { auto rank = shape.size(); return /*no swizzling*/ sharedEnc.getMaxPhase() == 1 || /*swizzling but same shape*/ shape == allocShape || diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 9fa0592c4803..cf5c05f4a91b 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -9,7 +9,7 @@ #include "triton/Tools/LinearLayout.h" namespace mlir::triton::gpu { -class SharedEncodingAttr; +class SharedEncodingTrait; // - BlockedEncodingAttrs have the following input dimensions. // @@ -48,7 +48,7 @@ LinearLayout toLinearLayout(ArrayRef shape, Attribute layout, // If `disableSwizzle` is set, then the resulting layout does not include // swizzling. 
LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, - SharedEncodingAttr shared, + SharedEncodingTrait shared, int32_t elemBitWidth, bool disableSwizzle = false); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index f9b8ae2d97c0..6212cfaa67cf 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -194,7 +194,7 @@ bool isPureUnaryInlineAsm(Operation *op); // read the compute capability from the module attributes int getNVIDIAComputeCapability(Operation *module); -std::optional +std::optional getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); enum class MMALoadType { diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 4f45cfaccac8..cb358b8f1cfc 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -76,8 +76,6 @@ struct LocalAllocOpConversion LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto resultTy = cast(op.getType()); auto typeConverter = getTypeConverter(); - auto sharedLayout = - cast(resultTy.getEncoding()); auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(), diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 5d779578ce37..cbc2abd83ac9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -208,7 +208,7 @@ Value getSmemVecAddr(const LinearLayout ®Layout, auto allocShape = sharedTy.getAllocShape(); auto rank = shape.size(); auto sharedEnc = - dyn_cast(sharedTy.getEncoding()); + dyn_cast(sharedTy.getEncoding()); auto smemBase = smemObj.getBase(); auto smemOffsets = smemObj.getOffsets(); diff --git 
a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 5dfd10077b5d..978111e8f547 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -279,7 +279,6 @@ struct MemDescTransOpConversion ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto resultTy = cast(op.getType()); - auto enc = cast(resultTy.getEncoding()); auto llvmElemTy = getTypeConverter()->convertType(resultTy.getElementType()); auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index b72635350061..9c10e4d35dba 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -119,7 +119,7 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, } LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, - SharedEncodingAttr shared) { + SharedEncodingTrait shared) { assert(!shared.getHasLeadingOffset()); MLIRContext *ctx = shared.getContext(); @@ -169,7 +169,7 @@ LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, } // namespace LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, - SharedEncodingAttr shared, + SharedEncodingTrait shared, int32_t elemBitWidth, bool disableSwizzle) { assert(shared.getHasLeadingOffset()); @@ -888,7 +888,7 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, if (auto distributed = dyn_cast(layout)) { result = distributed.toLinearLayout(shape); } else { - auto shared = dyn_cast(layout); + auto shared = dyn_cast(layout); if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp 
b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4f6a59093eb3..c37b59ed0573 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/ErrorHandling.h" #include namespace mlir { @@ -331,7 +332,6 @@ class FuseTransMMAV3Plus : public OpRewritePattern { return failure(); MemDescType allocType = allocOp.getType(); - auto allocEncoding = cast(allocType.getEncoding()); RankedTensorType srcTy = trans.getSrc().getType(); // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 @@ -352,9 +352,15 @@ class FuseTransMMAV3Plus : public OpRewritePattern { // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should // be used here. For tests where numCTAs = 1, this is not a problem since // all CTALayouts are the same. 
- auto newInnerEnc = SharedEncodingAttr::get( - getContext(), srcTy.getShape(), newInnerCvtOrder, - allocEncoding.getCTALayout(), srcTy.getElementType()); + SharedEncodingTrait newInnerEnc; + if (auto allocEncoding = + dyn_cast(allocType.getEncoding())) { + newInnerEnc = SharedEncodingAttr::get( + getContext(), srcTy.getShape(), newInnerCvtOrder, + allocEncoding.getCTALayout(), srcTy.getElementType()); + } else { + llvm_unreachable("Unsupported shared encoding"); + } MemDescType innerTy = MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 3ff045937b05..6b629b07e41c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -37,7 +37,7 @@ namespace { struct LoadInfo { // Layout of the data in shared memory. - ttg::SharedEncodingAttr sharedEncoding = nullptr; + ttg::SharedEncodingTrait sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. ttg::BlockedEncodingAttr blockedEncoding = nullptr; bool isMMAv3Shared = false; @@ -578,7 +578,7 @@ assignMemoryLayouts(scf::ForOp &forOp, // Create an allocation that can hold distance number of loadOp shapes. 
static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, - ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + ttg::SharedEncodingTrait sharedEnc, unsigned distance) { OpBuilder builder(forOp); Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 2b4aa98ad647..5b46ce1524b5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -985,7 +985,7 @@ int getNVIDIAComputeCapability(Operation *module) { // the same dot operand encoding, return the shared encoding that needs to be // used to be compatible with users' layouts. If there are incompatible shared // encodings, set incompatible to true. -std::optional +std::optional getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; incompatible = false; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index c57c3bf1824a..1795271ddc71 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -14,7 +14,7 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingTrait; mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader( Value tensor, Value base, SmallVector instrShape, @@ -344,12 +344,12 @@ void convertDot(const LLVMTypeConverter *typeConverter, bool aInTmem = true; bool transA = false; if (auto aSharedLayout = - dyn_cast(aTensorTy.getEncoding())) { + dyn_cast(aTensorTy.getEncoding())) { auto aOrd = aSharedLayout.getOrder(); transA = aOrd[0] == 0; 
aInTmem = false; } - auto bSharedLayout = cast(bTensorTy.getEncoding()); + auto bSharedLayout = cast(bTensorTy.getEncoding()); auto bOrd = bSharedLayout.getOrder(); bool transB = bOrd[0] == 1; Value baseA = @@ -491,12 +491,12 @@ struct TCGen5MMAScaledOpConversion bool aInTmem = true; bool transA = false; if (auto aSharedLayout = - dyn_cast(aTensorTy.getEncoding())) { + dyn_cast(aTensorTy.getEncoding())) { auto aOrd = aSharedLayout.getOrder(); transA = aOrd[0] == 0; aInTmem = false; } - auto bSharedLayout = cast(bTensorTy.getEncoding()); + auto bSharedLayout = cast(bTensorTy.getEncoding()); auto bOrd = bSharedLayout.getOrder(); bool transB = bOrd[0] == 1; Value baseA = From 78e29dc2c86d62ecf4c60e976bcc64d1ea9e7703 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 13:47:05 +0900 Subject: [PATCH 08/17] more --- .../SharedToDotOperandFMA.cpp | 10 +++---- .../TritonGPUToLLVM/TypeConverter.cpp | 1 - lib/Dialect/TritonGPU/IR/Dialect.cpp | 28 +++++++++++-------- lib/Dialect/TritonGPU/IR/Ops.cpp | 4 +-- .../Transforms/CoalesceAsyncCopy.cpp | 2 +- .../Transforms/Pipeliner/AssignLatencies.cpp | 4 +-- .../Pipeliner/MatmulLoopPipeline.cpp | 24 ---------------- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 6 ++-- .../ConvertLayoutOpToLLVM.cpp | 1 - .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 1 - .../DecomposeUnsupportedConversions.cpp | 25 +++++++++++------ .../LoadStoreOpToLLVM.cpp | 8 +++--- .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 8 +++--- 13 files changed, 55 insertions(+), 67 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index 6dbbe0a110ef..ac6ec2501c71 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -15,7 +15,7 @@ using ::mlir::triton::gpu::getShapePerCTA; using 
::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::MemDescType; -using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingTrait; Value getStructFromValueTable(ArrayRef vals, ConversionPatternRewriter &rewriter, Location loc, @@ -32,11 +32,11 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } +bool isSwizzled(SharedEncodingTrait layout) { return layout.getMaxPhase() != 1; } SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, Location loc, SmallVector rawIndices, - SharedEncodingAttr layout) { + SharedEncodingTrait layout) { auto b = TritonLLVMOpBuilder(loc, rewriter); const auto &order = layout.getOrder(); auto rank = order.size(); @@ -154,7 +154,7 @@ Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc, Value bTileOffset, Value nonKTileOffset, unsigned shapePerCTABTile, unsigned shapePerCTANonKTile, - SharedEncodingAttr sharedLayout, + SharedEncodingTrait sharedLayout, ArrayRef opTensorShape, ArrayRef strides) { auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -229,7 +229,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, dim.nonK = dotOpNo == 0 ? 
1 : 2; auto opTensorTy = cast(srcVal.getType()); auto opTensorShape = expandMatrixShapeWithBatch(opTensorTy.getShape()); - auto sharedLayout = cast(opTensorTy.getEncoding()); + auto sharedLayout = cast(opTensorTy.getEncoding()); auto opOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 092501ce10ef..bb658b2aa716 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -10,7 +10,6 @@ using namespace mlir::triton; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::MemDescType; -using ::mlir::triton::gpu::SharedEncodingAttr; TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &options, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index e143444d5721..c3ce3088c684 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -284,7 +284,7 @@ SmallVector getOrder(Attribute layout) { } return order; } - if (auto sharedLayout = mlir::dyn_cast(layout)) { + if (auto sharedLayout = mlir::dyn_cast(layout)) { return llvm::to_vector(sharedLayout.getOrder()); } if (auto linearLayout = mlir::dyn_cast(layout)) { @@ -378,12 +378,12 @@ SmallVector getShapePerCTA(ArrayRef CTASplitNum, } SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { - if (auto sharedLayout = mlir::dyn_cast(layout)) { + if (auto sharedLayout = mlir::dyn_cast(layout)) { // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. // The first dim of shape is numStages. This is a work around, otherwise // too many places would have to be modified in pipeline pass. Maybe we // need to refactor this logic in the future. 
- auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + auto CTASplitNum = sharedLayout.getCTASplitNum(); if (shape.size() == CTASplitNum.size() + 1) { auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); res.insert(res.begin(), shape.front()); @@ -421,8 +421,8 @@ unsigned getNumWarpsPerCTA(Attribute layout) { warpsPerCTA = wmmaLayout.getWarpsPerCTA(); else if (auto dotLayout = dyn_cast(layout)) warpsPerCTA = dotLayout.getWarpsPerCTA(); - else if (auto sharedLayout = dyn_cast(layout)) - llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); + else if (auto sharedLayout = dyn_cast(layout)) + llvm::report_fatal_error("Cannot get numWarps from SharedEncodingTrait"); else llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); return product(warpsPerCTA); @@ -2423,6 +2423,8 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { } else if (auto sharedAttr = mlir::dyn_cast(attr)) { os << "shared"; return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + llvm_unreachable("Unsupported shared encoding"); } else if (auto blockedAttr = mlir::dyn_cast(attr)) { os << "blocked"; return AliasResult::FinalAlias; @@ -2515,18 +2517,22 @@ struct TritonGPUInferLayoutInterface applyPermutation(invOrderUnsigned, layout.getCTAOrder())); }; - if (auto enc = mlir::dyn_cast(operandEncoding)) { + if (auto enc = mlir::dyn_cast(operandEncoding)) { if (enc.getOrder().size() != order.size()) { return failure(); } - FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + FailureOr ctaLayout = permuteCTALayout(getCTALayout(enc)); if (failed(ctaLayout)) { return failure(); } - resultEncoding = SharedEncodingAttr::get( - ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), - applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout, - enc.getHasLeadingOffset()); + if (isa(enc)) { + resultEncoding = SharedEncodingAttr::get( + ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), + 
applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout, + enc.getHasLeadingOffset()); + } else { + llvm_unreachable("Unsupported shared encoding"); + } return success(); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 5e3566f47c83..2fa6e74a5d7f 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -608,11 +608,11 @@ LogicalResult MemDescSubviewOp::verify() { if (!isa(srcEnc) && !isa(srcEnc)) { - return emitError("src encoding must be SharedEncodingAttr"); + return emitError("src encoding must be SharedEncodingTrait"); } if (!isa(dstEnc) && !isa(srcEnc)) { - return emitError("result encoding must be SharedEncodingAttr"); + return emitError("result encoding must be SharedEncodingTrait"); } if (isa(srcEnc)) { diff --git a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp index 1266101beea3..b51a004724d7 100644 --- a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp @@ -45,7 +45,7 @@ struct ClipAsyncCopySizePerThread if (!blockedEnc) return rewriter.notifyMatchFailure(copyOp, "src must be of blocked encoding"); - auto sharedEnc = cast(dstTy.getEncoding()); + auto sharedEnc = cast(dstTy.getEncoding()); auto sharedVec = sharedEnc.getVec(); // obtain max contiguous copy size diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp index 62face176f55..92154e26516c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -45,12 +45,12 @@ bool canHaveSharedEncoding(tt::LoadOp op) { if (llvm::any_of(op->getUsers(), [&](Operation *user) { return isa(user); })) { - ttg::SharedEncodingAttr localAllocEnc; + ttg::SharedEncodingTrait localAllocEnc; for (auto user : op->getUsers()) { auto localAlloc = 
dyn_cast(user); if (!localAlloc) continue; - auto enc = mlir::cast( + auto enc = mlir::cast( localAlloc.getType().getEncoding()); if (!localAllocEnc) { localAllocEnc = enc; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 6b629b07e41c..d80299317300 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -402,30 +402,6 @@ getSharedEncoding(Operation *loadOp, bool isTMALoad) { ctaLayout); } -static bool hasSharedEncodingHelper(Operation *loadOp) { - // If the load is used by a LocalAllocOp, use the same encoding as the allocs. - // If the allocs don't all have the same encoding, bail. - if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { - return isa(user); - })) { - ttg::SharedEncodingAttr localAllocEnc; - for (auto user : loadOp->getUsers()) { - auto localAlloc = dyn_cast(user); - if (!localAlloc) - continue; - auto enc = mlir::cast( - localAlloc.getType().getEncoding()); - if (!localAllocEnc) { - localAllocEnc = enc; - } - if (enc != localAllocEnc) - return false; - } - return true; - } - return true; -} - static llvm::SmallVector getDirectUserInBlock(Operation *loadOp) { llvm::SmallVector users; DenseSet seen; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 5b46ce1524b5..5d00fab838cf 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -987,17 +987,17 @@ int getNVIDIAComputeCapability(Operation *module) { // encodings, set incompatible to true. 
std::optional getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { - ttg::SharedEncodingAttr attr; + ttg::SharedEncodingTrait attr; incompatible = false; for (Operation *user : val.getUsers()) { - ttg::SharedEncodingAttr tempAttr; + ttg::SharedEncodingTrait tempAttr; if (user->getNumResults() != 1) return std::nullopt; if (auto memDesc = dyn_cast(user->getResult(0).getType())) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. - tempAttr = cast(memDesc.getEncoding()); + tempAttr = cast(memDesc.getEncoding()); if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) .has_value()) return std::nullopt; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 59a3f25648ba..c4f3c46782b1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -8,7 +8,6 @@ using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MemDescType; -using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandMFMA { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index 3b212f9d9660..a7390649f1e4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -8,7 +8,6 @@ using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MemDescType; -using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandMFMA { Value convertLayout(int 
opIdx, ConversionPatternRewriter &rewriter, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index d9c4909c421c..a1dd0cf6f451 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -4,7 +4,10 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/Patterns.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; @@ -35,7 +38,7 @@ class DecomposeLocalLoadToDotOperand auto dstDotOp = dyn_cast( op.getType().getEncoding()); MemDescType srcType = op.getSrc().getType(); - auto sharedEncoding = dyn_cast(srcType.getEncoding()); + auto sharedEncoding = dyn_cast(srcType.getEncoding()); if (!dstDotOp || !sharedEncoding || !sharedEncoding.getHasLeadingOffset()) return failure(); RankedTensorType type = op.getType(); @@ -50,13 +53,19 @@ class DecomposeLocalLoadToDotOperand blockEncoding); Value load = rewriter.create(op.getLoc(), tmpType, op.getSrc()); - auto newSharedDescTy = MemDescType::get( - type.getShape(), type.getElementType(), - triton::gpu::SharedEncodingAttr::get( - op.getContext(), dstDotOp, type.getShape(), - triton::gpu::getOrder(parentEnc), - triton::gpu::getCTALayout(parentEnc), type.getElementType()), - srcType.getMemorySpace()); + + SharedEncodingTrait newSharedEnc; + if (isa(sharedEncoding)) { + newSharedEnc = triton::gpu::SharedEncodingAttr::get( + op.getContext(), dstDotOp, type.getShape(), + triton::gpu::getOrder(parentEnc), + triton::gpu::getCTALayout(parentEnc), type.getElementType()); + } else { + llvm_unreachable("Unsupporeted shared encoding"); + } + auto 
newSharedDescTy = + MemDescType::get(type.getShape(), type.getElementType(), newSharedEnc, + srcType.getMemorySpace()); auto tmp = rewriter.create( op.getLoc(), newSharedDescTy, load); auto newConvert = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index e19e2b0ecd1b..741c06a583c1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -30,7 +30,7 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getCTALayout; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingTrait; namespace ttg = mlir::triton::gpu; @@ -1047,7 +1047,7 @@ struct AsyncCopyGlobalToLocalOpConversion auto srcLayout = srcTy.getEncoding(); assert((isa(srcLayout) && "Unexpected srcLayout in AsyncCopyGlobalToLocalOpConversion")); - auto resSharedLayout = cast(dstTy.getEncoding()); + auto resSharedLayout = cast(dstTy.getEncoding()); Value llDst = adaptor.getResult(); Value llSrc = adaptor.getSrc(); @@ -1361,7 +1361,7 @@ struct AsyncTMACopyLocalToGlobalOpConversion static LinearLayout getUnswizzledLayout(triton::gpu::MemDescType type) { return triton::gpu::sharedToLinearLayoutLeadingOffset( - type.getShape(), cast(type.getEncoding()), + type.getShape(), cast(type.getEncoding()), type.getElementTypeBitWidth(), /*disableSwizzle=*/true); } @@ -1412,7 +1412,7 @@ static LogicalResult iterateGatherScatterIndices( return op->emitError("memdesc shape must match alloc shape"); // `hasLeadingOffset` means the core matrix tiles are placed next to each // other in shared memory, which lines up with how `gather4` loads data. 
- if (!cast(smemType.getEncoding()).getHasLeadingOffset()) + if (!cast(smemType.getEncoding()).getHasLeadingOffset()) return op->emitError("requires dst encoding with `hasLeadingOffset=true`"); Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index d16de7ae2d34..d90803580168 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -36,7 +36,7 @@ struct LocalLoadOpConversion cast(dstLayout).getParent())) { auto dotEnc = cast(dstLayout); auto mmaEnc = cast(dotEnc.getParent()); - auto sharedEnc = cast(srcLayout); + auto sharedEnc = cast(srcLayout); auto bitwidth = dstTy.getElementTypeBitWidth(); auto vecWidth = 32 / bitwidth; auto kWidth = dotEnc.getKWidth(); @@ -77,7 +77,7 @@ struct LocalLoadOpConversion auto dstTy = cast(op.getType()); auto srcTy = cast(op.getSrc().getType()); auto dotEnc = cast(dstTy.getEncoding()); - auto sharedEnc = cast(srcTy.getEncoding()); + auto sharedEnc = cast(srcTy.getEncoding()); auto shape = dstTy.getShape(); auto rank = dstTy.getRank(); auto kOrder = dotEnc.getOpIdx() == 0 ? 
rank - 1 : rank - 2; @@ -146,7 +146,7 @@ LogicalResult lowerDistributedToSharedStmatrix( if (!mmaEncoding) return failure(); auto sharedLayout = - cast(memDescType.getEncoding()); + cast(memDescType.getEncoding()); if (!sharedLayout.getHasLeadingOffset()) return failure(); int swizzleByteSize = 0; @@ -223,7 +223,7 @@ struct LocalAllocOpConversion return failure(); MemDescType memDescType = op.getType(); auto sharedLayout = - cast(memDescType.getEncoding()); + cast(memDescType.getEncoding()); RankedTensorType srcTy = op.getSrc().getType(); Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); Value smemBase = From 60a9d89d4e00bf996473f22bba0b3c951af7e461 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 13:55:38 +0900 Subject: [PATCH 09/17] more --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 9 +++++---- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 ++--- lib/Dialect/TritonGPU/IR/Ops.cpp | 2 +- .../Transforms/Pipeliner/MatmulLoopPipeline.cpp | 6 +++--- .../TritonNvidiaGPU/Transforms/FenceInsertion.cpp | 2 -- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index a096abe4c6e3..81fed21d7edd 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -177,6 +177,9 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", "SmallVector", "getCTAOrder">, + InterfaceMethod<"", + "int32_t", + "getAlignment">, InterfaceMethod<"", "bool", "getHasLeadingOffset"> @@ -199,6 +202,7 @@ class SharedEncoding traits = [], SmallVector getCTASplitNum() const; SmallVector getCTAOrder() const; bool getHasLeadingOffset() const; + int32_t getAlignment() const; }]; } @@ -485,10 +489,7 @@ compared to 1*64 when the hasLeadingOffset is false. 
}]> ]; - let extraClassDeclaration = extraSharedEncodingDeclaration # [{ - // TODO: Should this be part of the interface? - int32_t getAlignment() const; - }]; + let extraClassDeclaration = extraSharedEncodingDeclaration; let hasCustomAssemblyFormat = 1; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index c3ce3088c684..e2ec9e4aa5a7 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3387,10 +3387,9 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, // tensorType is needed later on (e.g., getDimSize(j)), so we still have to // pass it as a param - if (auto sharedLayout = mlir::dyn_cast(layout)) { + if (isa(layout)) { return getSharedLayoutStr(tensorType, useHWPointOfView); - } else if (auto distributedLayout = - mlir::dyn_cast(layout)) { + } else if (isa(layout)) { return getDistributedLayoutStr(tensorType, useHWPointOfView); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 2fa6e74a5d7f..d37d3a8ca443 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -654,7 +654,7 @@ int32_t LocalAllocOp::getAlignmentOrDefault() { } auto ty = getType(); - auto enc = dyn_cast(ty.getEncoding()); + auto enc = dyn_cast(ty.getEncoding()); return enc ? 
enc.getAlignment() : 16; } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index d80299317300..4019707f2cef 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -351,7 +351,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { threadsPerWarp, ctaLayout); } -static std::optional +static std::optional getSharedEncoding(Operation *loadOp, bool isTMALoad) { auto ty = cast(loadOp->getResultTypes()[0]); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); @@ -380,12 +380,12 @@ getSharedEncoding(Operation *loadOp, bool isTMALoad) { if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { return isa(user); })) { - ttg::SharedEncodingAttr localAllocEnc; + ttg::SharedEncodingTrait localAllocEnc; for (auto user : loadOp->getUsers()) { auto localAlloc = dyn_cast(user); if (!localAlloc) continue; - auto enc = mlir::cast( + auto enc = mlir::cast( localAlloc.getType().getEncoding()); if (!localAllocEnc) { localAllocEnc = enc; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fdb18..89545b9712f5 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -21,8 +21,6 @@ namespace ttng = ::mlir::triton::nvidia_gpu; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" -using ::mlir::triton::gpu::SharedEncodingAttr; - namespace { struct FenceInsertionPass From 2a81d6cdfacede7acebc26569fee26873f4e0884 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 13:59:29 +0900 Subject: [PATCH 10/17] pre-commit --- .../ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp | 4 +++- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 3 ++- 2 
files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index ac6ec2501c71..ff389f9851de 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -32,7 +32,9 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -bool isSwizzled(SharedEncodingTrait layout) { return layout.getMaxPhase() != 1; } +bool isSwizzled(SharedEncodingTrait layout) { + return layout.getMaxPhase() != 1; +} SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, Location loc, SmallVector rawIndices, diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 4019707f2cef..1ea0afa2c74f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -554,7 +554,8 @@ assignMemoryLayouts(scf::ForOp &forOp, // Create an allocation that can hold distance number of loadOp shapes. 
static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, - ttg::SharedEncodingTrait sharedEnc, unsigned distance) { + ttg::SharedEncodingTrait sharedEnc, + unsigned distance) { OpBuilder builder(forOp); Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); From c0a7431bd52f2dc9a6d05338f182f2efbc168401 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Jan 2025 14:35:32 +0900 Subject: [PATCH 11/17] getHasLeadingOffset -> hasLeadingOffset --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 4 ++-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 8 ++++---- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 6 +++--- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- .../DecomposeUnsupportedConversions.cpp | 2 +- .../lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- .../nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 4 ++-- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 81fed21d7edd..7d94a3155f32 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -182,7 +182,7 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { "getAlignment">, InterfaceMethod<"", "bool", - "getHasLeadingOffset"> + "hasLeadingOffset"> ]; } @@ -201,7 +201,7 @@ class SharedEncoding traits = [], SmallVector getCTAsPerCGA() const; SmallVector getCTASplitNum() const; SmallVector getCTAOrder() const; - bool getHasLeadingOffset() const; + bool hasLeadingOffset() const; int32_t getAlignment() const; }]; } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index cbc2abd83ac9..3e8fe3fedf15 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -253,7 +253,7 @@ Value 
getSmemVecAddr(const LinearLayout &regLayout, .second; } else { // Case 2 -> rank-reduced swizzling assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); - assert(!sharedEnc.getHasLeadingOffset() && + assert(!sharedEnc.hasLeadingOffset() && "Leading offsets are not supported for sliced tensors"); // We define both tensor offsets and shared memory offsets: // diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index e2ec9e4aa5a7..47354e6ca7ab 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -909,12 +909,12 @@ SmallVector SharedEncodingAttr::getCTASplitNum() const { return SmallVector(getCTALayout().getCTASplitNum()); } -bool SharedEncodingAttr::getHasLeadingOffset() const { +bool SharedEncodingAttr::hasLeadingOffset() const { return getHasLeadingOffset__(); } int32_t SharedEncodingAttr::getAlignment() const { - if (getHasLeadingOffset()) + if (hasLeadingOffset()) return 128 * getMaxPhase(); return 16; } @@ -1903,7 +1903,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { << ", order = [" << ArrayRef(getOrder()) << "]"; maybePrintCTALayout(getContext(), printer, getCTALayout(), /*rank=*/getOrder().size()); - printer << ", hasLeadingOffset = " << getHasLeadingOffset() << "}>"; + printer << ", hasLeadingOffset = " << hasLeadingOffset() << "}>"; } //===----------------------------------------------------------------------===// @@ -2529,7 +2529,7 @@ struct TritonGPUInferLayoutInterface resultEncoding = SharedEncodingAttr::get( ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout, - enc.getHasLeadingOffset()); + enc.hasLeadingOffset()); } else { llvm_unreachable("Unsupported shared encoding"); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 9c10e4d35dba..e91398c9db82 100644 ---
a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -120,7 +120,7 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, SharedEncodingTrait shared) { - assert(!shared.getHasLeadingOffset()); + assert(!shared.hasLeadingOffset()); MLIRContext *ctx = shared.getContext(); int rank = shape.size(); @@ -172,7 +172,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, SharedEncodingTrait shared, int32_t elemBitWidth, bool disableSwizzle) { - assert(shared.getHasLeadingOffset()); + assert(shared.hasLeadingOffset()); MLIRContext *ctx = shared.getContext(); int rank = shape.size(); @@ -889,7 +889,7 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, result = distributed.toLinearLayout(shape); } else { auto shared = dyn_cast(layout); - if (shared.getHasLeadingOffset()) { + if (shared.hasLeadingOffset()) { assert(elemBitWidth.has_value()); result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); } else { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 5d00fab838cf..363d0cc8a8af 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1035,7 +1035,7 @@ MMALoadType getMMALoadType(Operation *loadOp) { auto sharedEnc = cast(alloc.getType().getEncoding()); - if (!sharedEnc.getHasLeadingOffset()) + if (!sharedEnc.hasLeadingOffset()) return MMALoadType::DoNotPipeline; // MMA V3 case. 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index a1dd0cf6f451..bb82670904da 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -39,7 +39,7 @@ class DecomposeLocalLoadToDotOperand op.getType().getEncoding()); MemDescType srcType = op.getSrc().getType(); auto sharedEncoding = dyn_cast(srcType.getEncoding()); - if (!dstDotOp || !sharedEncoding || !sharedEncoding.getHasLeadingOffset()) + if (!dstDotOp || !sharedEncoding || !sharedEncoding.hasLeadingOffset()) return failure(); RankedTensorType type = op.getType(); auto parentEnc = dstDotOp.getParent(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 741c06a583c1..bdf639f25bde 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1412,7 +1412,7 @@ static LogicalResult iterateGatherScatterIndices( return op->emitError("memdesc shape must match alloc shape"); // `hasLeadingOffset` means the core matrix tiles are placed next to each // other in shared memory, which lines up with how `gather4` loads data. 
- if (!cast(smemType.getEncoding()).getHasLeadingOffset()) + if (!cast(smemType.getEncoding()).hasLeadingOffset()) return op->emitError("requires dst encoding with `hasLeadingOffset=true`"); Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index d90803580168..79548bf56e9c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -47,7 +47,7 @@ struct LocalLoadOpConversion // Limitation 1 [TODO: remove]: Check LL bases to verify register and // address alignment auto canUseLdmatrix = - (kWidth == vecWidth) && (!sharedEnc.getHasLeadingOffset()); + (kWidth == vecWidth) && (!sharedEnc.hasLeadingOffset()); canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) || (sharedEnc.getVec() * bitwidth >= 8 * 16); auto shape = srcTy.getShape(); @@ -147,7 +147,7 @@ LogicalResult lowerDistributedToSharedStmatrix( return failure(); auto sharedLayout = cast(memDescType.getEncoding()); - if (!sharedLayout.getHasLeadingOffset()) + if (!sharedLayout.hasLeadingOffset()) return failure(); int swizzleByteSize = 0; if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) From 8022d430d37b619360cffa1ddcf24091e3ad18f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Jan 2025 10:43:57 +0900 Subject: [PATCH 12/17] template name and attrMnenoic --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 7d94a3155f32..4ba3d00c472c 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -186,9 +186,9 @@ 
def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { ]; } -class SharedEncoding traits = [], +class SharedEncoding traits = [], Dialect dialect = TritonGPU_Dialect> - : TritonGPU_Attr<"SharedEncoding", "shared_encoding", !listconcat([SharedEncodingTrait], traits), dialect> { + : TritonGPU_Attr { let description = [{ }]; @@ -206,8 +206,7 @@ class SharedEncoding traits = [], }]; } -// def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { -def SharedEncodingAttr : SharedEncoding { +def SharedEncodingAttr : SharedEncoding<"SharedEncoding", "shared_encoding"> { let mnemonic = "shared"; let description = [{ From 2b3f3009553be987f6961d7dc829d2a3ed603883 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Jan 2025 10:50:00 +0900 Subject: [PATCH 13/17] minor --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 363d0cc8a8af..9b1479671438 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1033,7 +1033,7 @@ MMALoadType getMMALoadType(Operation *loadOp) { if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { auto sharedEnc = - cast(alloc.getType().getEncoding()); + cast(alloc.getType().getEncoding()); if (!sharedEnc.hasLeadingOffset()) return MMALoadType::DoNotPipeline; From 3698ffc55c10f964fb08890f1a8deb27c6892a9c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Jan 2025 16:48:41 +0900 Subject: [PATCH 14/17] add desc to td file --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 116 +++++++++--------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 4ba3d00c472c..73413c11e59e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ 
b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -152,63 +152,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { let cppNamespace = "::mlir::triton::gpu"; - let description = [{ - }]; - - let methods = [ - InterfaceMethod<"", - "unsigned", - "getVec">, - InterfaceMethod<"", - "unsigned", - "getPerPhase">, - InterfaceMethod<"", - "unsigned", - "getMaxPhase">, - InterfaceMethod<"", - "SmallVector", - "getOrder">, - InterfaceMethod<"Get the shape of the CTAs per CGA.", - "SmallVector", - "getCTAsPerCGA">, - InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", - "SmallVector", - "getCTASplitNum">, - InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", - "SmallVector", - "getCTAOrder">, - InterfaceMethod<"", - "int32_t", - "getAlignment">, - InterfaceMethod<"", - "bool", - "hasLeadingOffset"> - ]; -} - -class SharedEncoding traits = [], - Dialect dialect = TritonGPU_Dialect> - : TritonGPU_Attr { - - let description = [{ - }]; - - code extraSharedEncodingDeclaration = extraBaseClassDeclaration # [{ - unsigned getVec() const; - unsigned getPerPhase() const; - unsigned getMaxPhase() const; - SmallVector getOrder() const; - SmallVector getCTAsPerCGA() const; - SmallVector getCTASplitNum() const; - SmallVector getCTAOrder() const; - bool hasLeadingOffset() const; - int32_t getAlignment() const; - }]; -} - -def SharedEncodingAttr : SharedEncoding<"SharedEncoding", "shared_encoding"> { - let mnemonic = "shared"; - let description = [{ An encoding for tensors whose elements may be simultaneously accessed by different cuda threads in the programs, via shared memory. In other words, @@ -291,6 +234,65 @@ a matrix of size 16x128 and data type I8 is stored in the shared memory with compared to 1*64 when the hasLeadingOffset is false. 
}]; + let methods = [ + InterfaceMethod<"Get 'vec', one of properties of swizzling.", + "unsigned", + "getVec">, + InterfaceMethod<"Get 'per phase', one of properties of swizzling.", + "unsigned", + "getPerPhase">, + InterfaceMethod<"Get 'max phase', one of properties of swizzling.", + "unsigned", + "getMaxPhase">, + InterfaceMethod<"Get the order of this SMEM encoding.", + "SmallVector", + "getOrder">, + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + InterfaceMethod<"Get the size of the address alignment in bytes.", + "int32_t", + "getAlignment">, + InterfaceMethod<"True if the SMEM layout is in the core-matrices format.", + "bool", + "hasLeadingOffset"> + ]; +} + +class SharedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +The interface which can be inherited by various representations of an SMEM encoding. + }]; + + code extraSharedEncodingDeclaration = extraBaseClassDeclaration # [{ + unsigned getVec() const; + unsigned getPerPhase() const; + unsigned getMaxPhase() const; + SmallVector getOrder() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTASplitNum() const; + SmallVector getCTAOrder() const; + bool hasLeadingOffset() const; + int32_t getAlignment() const; + }]; +} + +def SharedEncodingAttr : SharedEncoding<"SharedEncoding", "shared_encoding"> { + let mnemonic = "shared"; + + let description = [{ +The concrete attribute class for an SMEM encoding, used by multiple GPU backends. 
+ }]; + // swizzle info: vec, perPhase, maxPhase // order: the fastest-changing axis first let parameters = ( From 46a67731f09294d8bbc274d02ef7d500916a7c34 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Jan 2025 16:33:37 +0900 Subject: [PATCH 15/17] minor --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 8 +++++--- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 2 +- .../DecomposeUnsupportedConversions.cpp | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 47354e6ca7ab..ccb562792d74 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -362,7 +362,7 @@ SmallVector getCTAOrder(Attribute layout) { } else { llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); } - llvm_unreachable("Unhandled case"); + llvm_unreachable("Unhandled case in getCTAOrder"); } SmallVector getShapePerCTA(ArrayRef CTASplitNum, @@ -2424,7 +2424,8 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { os << "shared"; return AliasResult::FinalAlias; } else if (auto sharedAttr = mlir::dyn_cast(attr)) { - llvm_unreachable("Unsupported shared encoding"); + llvm_unreachable( + "Unsupported shared encoding in TritonGPUOpAsmInterface::getAlias"); } else if (auto blockedAttr = mlir::dyn_cast(attr)) { os << "blocked"; return AliasResult::FinalAlias; @@ -2531,7 +2532,8 @@ struct TritonGPUInferLayoutInterface applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout, enc.hasLeadingOffset()); } else { - llvm_unreachable("Unsupported shared encoding"); + llvm_unreachable("Unsupported shared encoding in " + "TritonGPUInferLayoutInterface::inferTransOpEncoding"); } return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index c37b59ed0573..7d7e58f67767 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ 
b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -359,7 +359,7 @@ class FuseTransMMAV3Plus : public OpRewritePattern { getContext(), srcTy.getShape(), newInnerCvtOrder, allocEncoding.getCTALayout(), srcTy.getElementType()); } else { - llvm_unreachable("Unsupported shared encoding"); + llvm_unreachable("Unsupported shared encoding in FuseTransMMAV3Plus"); } MemDescType innerTy = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index bb82670904da..784937e2a93e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -61,7 +61,8 @@ class DecomposeLocalLoadToDotOperand triton::gpu::getOrder(parentEnc), triton::gpu::getCTALayout(parentEnc), type.getElementType()); } else { - llvm_unreachable("Unsupporeted shared encoding"); + llvm_unreachable( + "Unsupporeted shared encoding in DecomposeLocalLoadToDotOperand"); } auto newSharedDescTy = MemDescType::get(type.getShape(), type.getElementType(), newSharedEnc, From 1a3ecd0cf124b7c1f3e152f58c99422661326d48 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Mon, 27 Jan 2025 22:23:05 +0000 Subject: [PATCH 16/17] Support shared encoding defined with linear layout This PR enables basic support for linear layout used as a shared encoding. 
--- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 59 ++++++++++++++++--- .../TritonGPU/IR/LinearLayoutConversions.cpp | 4 +- test/Conversion/amd/tritongpu_to_llvm.mlir | 22 +++++++ .../SharedToDotOperandMFMA.cpp | 6 +- 5 files changed, 84 insertions(+), 10 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 3e8fe3fedf15..1a50155f0ff3 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -244,7 +244,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout, // We propose case 2 (see comments below), which provides a more general // solution for all swizzled shared memory scenarios, including the edge case // mentioned above. - if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 + if (sharedEnc && + isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, {{kRegister, regId}, {kLane, laneId}, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index ccb562792d74..b3051fe7f395 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1185,20 +1185,46 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "}>"; } +//===----------------------------------------------------------------------===// +// Linear Encoding +//===----------------------------------------------------------------------===// + +namespace { + +bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) { + return (ll.getInDimNames().begin()->str() == "offset"); +} + +bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) { + return (ll.getInDimNames().begin()->str() == "register"); +} + +} // namespace + // FIXME Can we take the LinearLayout by const&?
LogicalResult LinearEncodingAttr::verify(function_ref emitError, LinearLayout linearLayout) { - // Example of LinearEncodingAttr + // Examples of LinearEncodingAttr // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], // warp = [[16, 0], [32, 0]], // block = []}> - // The input dims must be {register, lane, warp, block} + // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, + // 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} or {offset, block} // The output dims of the linear layout should be dim0..dim[rank-1] - static const auto expectedInDims = - SmallVector({"register", "lane", "warp", "block"}); + SmallVector expectedInDims; + // check if this is a memory of distributed layout + if (isMemoryEncoding(linearLayout)) { + expectedInDims = {"offset", "block"}; + } else { + assert(isDistributedEncoding(linearLayout)); + expectedInDims = {"register", "lane", "warp", "block"}; + } + for (const auto &[i, dims] : llvm::enumerate( llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { const auto &[dim, expectedDimStr] = dims; @@ -1266,12 +1292,20 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { LinearLayout::BasesT bases; // Parse the basis names in order (the order is relevant) - std::vector inDimNames = {"register", "lane", "warp", "block"}; + std::vector inDimNames = {"offset", "register", "lane", "warp", + "block"}; for (const auto &inDimNameStr : inDimNames) { auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); Attribute value = dict.get(inDimName); + // there are two cases: + // memory layout: {offset, block} + // and distributed layout: {register, lane, warp, block} + // this check skips redundant fields + if (!value) + continue; + // Expecting an array of arrays auto arrayOfArraysAttr = mlir::dyn_cast(value); if (!arrayOfArraysAttr) { @@ -1471,14 +1505,20 @@ SmallVector LinearEncodingAttr::getSizePerThread() const { 
} SmallVector LinearEncodingAttr::getOrder() const { - auto rank = getLinearLayout().getNumOutDims(); + auto linearLayout = getLinearLayout(); + auto rank = linearLayout.getNumOutDims(); SmallVector order(rank); // Choose [rank-1, rank-2, ... 0] as the default order in case // there are dims that do not move in the register // This order is as good as any really std::iota(order.rbegin(), order.rend(), 0); - return orderPerDim(StringAttr::get(getContext(), "register"), order); + if (isMemoryEncoding(linearLayout)) { + return orderPerDim(StringAttr::get(getContext(), "offset"), order); + } else { + assert(isDistributedEncoding(linearLayout)); + return orderPerDim(StringAttr::get(getContext(), "register"), order); + } } LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { @@ -1486,6 +1526,11 @@ LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); llvm::SmallDenseMap namedShape; llvm::SmallVector permutedDims; + if (isMemoryEncoding(ll)) { + // memory encoding should be 1-to-1 compatible with tensor size + // TODO: check layout size equal to shape + return ll; + } for (auto dim : getRepOrder()) { permutedDims.push_back(canonicalDims[dim]); namedShape[canonicalDims[dim]] = shape[dim]; diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index e91398c9db82..2d8e10021028 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -885,7 +885,9 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, // Layouts are distributed or shared in triton core // To add a new layout add an else-if clause LinearLayout result = LinearLayout::empty(); - if (auto distributed = dyn_cast(layout)) { + if (auto linearLayout = dyn_cast(layout)) { + result = linearLayout.toLinearLayout(shape); + } else if (auto distributed = dyn_cast(layout)) { 
result = distributed.toLinearLayout(shape); } else { auto shared = dyn_cast(layout); diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 144d32c71df8..ccfb8612f243 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -294,3 +294,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = []}> +#trasposed_in_regs = #ttg.linear<{register = [[1, 0], [2, 0], [0, 1], [0, 2], [0, 4]], lane = [[0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @custom_shared_layout(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + + %opA = ttg.convert_layout %arg0 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + + %trasposed_in_regs = ttg.convert_layout %arg1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #trasposed_in_regs> + %allocated = ttg.local_alloc %trasposed_in_regs : (tensor<128x64xf16, #trasposed_in_regs>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %opB = ttg.local_load %allocated : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> 
tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + + %result = tt.dot %opA, %opB, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 7b885a75a111..088349cd5d49 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -214,7 +214,11 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); - auto sharedLayout = cast(aTensorTy.getEncoding()); + auto sharedLayout = dyn_cast(aTensorTy.getEncoding()); + if (!sharedLayout) { + // this converter supports only shared encoding + return Value(); + } auto order = sharedLayout.getOrder(); assert((rank == 2 || order[2] == 0) && "expect batch to be the slowest dimension"); From c24e6bfec49c5052bf4eb8d275a2d464b6311336 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 31 Jan 2025 19:16:13 +0000 Subject: [PATCH 17/17] introduce separate layout for linearShared --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 28 ++ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 288 +++++++++++++++--- .../TritonGPU/IR/LinearLayoutConversions.cpp | 15 +- test/Conversion/amd/tritongpu_to_llvm.mlir | 22 -- test/Conversion/tritongpu_to_llvm.mlir | 13 + 6 files changed, 298 insertions(+), 71 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 73413c11e59e..43cfb75dbba0 100644 --- 
a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -666,6 +666,34 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding" let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// Shared Linear Encoding +//===----------------------------------------------------------------------===// + +def SharedLinearEncodingAttr : SharedEncoding<"SharedLinearEncoding", "shared_linear_encoding"> { + let mnemonic = "shared_linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = ( + ins + LinearLayoutParam:$linearLayout, + "int32_t":$alignment__ + ); + + let extraClassDeclaration = extraSharedEncodingDeclaration; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Blocked Layout Encoding diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 1a50155f0ff3..be14af60896b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -244,7 +244,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout, // We propose case 2 (see comments below), which provides a more general // solution for all swizzled shared memory scenarios, including the edge case // mentioned above.
- if (sharedEnc && + if (mlir::isa(sharedTy.getEncoding()) && + sharedEnc && isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, {{kRegister, regId}, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index b3051fe7f395..04792b5b0cb1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1189,18 +1189,6 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { // Linear Encoding //===----------------------------------------------------------------------===// -namespace { - -bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) { - return (ll.getInDimNames().begin()->str() == "offset"); -} - -bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) { - return (ll.getInDimNames().begin()->str() == "register"); -} - -} // namespace - // FIXME Can we take the LinearLayout by const&? LogicalResult LinearEncodingAttr::verify(function_ref emitError, @@ -1210,20 +1198,12 @@ LinearEncodingAttr::verify(function_ref emitError, // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], // warp = [[16, 0], [32, 0]], // block = []}> - // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, - // 0]], - // block = []}> - // The input dims must be {register, lane, warp, block} or {offset, block} + // The input dims must be {register, lane, warp, block} // The output dims of the linear layout should be dim0..dim[rank-1] - SmallVector expectedInDims; + SmallVector expectedInDims = {"register", "lane", "warp", + "block"}; // check if this is a memory of distributed layout - if (isMemoryEncoding(linearLayout)) { - expectedInDims = {"offset", "block"}; - } else { - assert(isDistributedEncoding(linearLayout)); - expectedInDims = {"register", "lane", "warp", "block"}; - } for (const auto &[i, dims] : llvm::enumerate( llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { @@ -1292,20 
+1272,12 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { LinearLayout::BasesT bases; // Parse the basis names in order (the order is relevant) - std::vector inDimNames = {"offset", "register", "lane", "warp", - "block"}; + std::vector inDimNames = {"register", "lane", "warp", "block"}; for (const auto &inDimNameStr : inDimNames) { auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); Attribute value = dict.get(inDimName); - // there are two cases: - // memory layout: {offset, block} - // and distributed layout: {register, lane, warp, block} - // this check skips redundant fields - if (!value) - continue; - // Expecting an array of arrays auto arrayOfArraysAttr = mlir::dyn_cast(value); if (!arrayOfArraysAttr) { @@ -1513,12 +1485,7 @@ SmallVector LinearEncodingAttr::getOrder() const { // This order is as good as any really std::iota(order.rbegin(), order.rend(), 0); - if (isMemoryEncoding(linearLayout)) { - return orderPerDim(StringAttr::get(getContext(), "offset"), order); - } else { - assert(isDistributedEncoding(linearLayout)); - return orderPerDim(StringAttr::get(getContext(), "register"), order); - } + return orderPerDim(StringAttr::get(getContext(), "register"), order); } LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { @@ -1526,11 +1493,6 @@ LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); llvm::SmallDenseMap namedShape; llvm::SmallVector permutedDims; - if (isMemoryEncoding(ll)) { - // memory encoding should be 1-to-1 compatible with tensor size - // TODO: check layout size equal to shape - return ll; - } for (auto dim : getRepOrder()) { permutedDims.push_back(canonicalDims[dim]); namedShape[canonicalDims[dim]] = shape[dim]; @@ -1582,6 +1544,246 @@ unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } 
+//===----------------------------------------------------------------------===//
+// Shared Linear Encoding
+//===----------------------------------------------------------------------===//
+
+// Returns the number of output dimensions (dim0..dim[rank-1]) of `ll`.
+// Unlike reading the size of the first "offset" basis, this also works when
+// the "offset" input dimension has no bases at all.
+static unsigned getSharedLinearRank(const LinearLayout &ll) {
+  return llvm::to_vector(ll.getOutDimNames()).size();
+}
+
+// Verifies the structural invariants of a shared linear layout:
+//  * the input dims are exactly {offset, block}, in that order;
+//  * the output dims are named dim0, dim1, ...;
+//  * every basis vector has at most one non-zero component.
+// `alignment` is not inspected here.
+LogicalResult
+SharedLinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 LinearLayout linearLayout, int32_t alignment) {
+  // Example of a SharedLinearEncodingAttr:
+  // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]], block = [],
+  //   alignment = 16}>
+  // The input dims must be {offset, block}
+  // The output dims of the linear layout should be dim0..dim[rank-1]
+
+  const SmallVector<StringRef> expectedInDims = {"offset", "block"};
+  auto inDims = llvm::to_vector(linearLayout.getInDimNames());
+
+  // llvm::zip stops at the shorter range, so a layout with missing input
+  // dimensions would otherwise slip through the name check below.
+  if (inDims.size() != expectedInDims.size()) {
+    return emitError() << "Expected " << expectedInDims.size()
+                       << " input dimensions ('offset', 'block'). Got "
+                       << inDims.size();
+  }
+
+  for (const auto &[i, dims] :
+       llvm::enumerate(llvm::zip(inDims, expectedInDims))) {
+    const auto &[dim, expectedDimStr] = dims;
+    if (dim.getValue() != expectedDimStr) {
+      return emitError() << "Expected input dimension " << i << " to be '"
+                         << expectedDimStr << "'. Got " << dim;
+    }
+  }
+
+  // outDims are ['dim0', 'dim1', ...]
+  for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) {
+    if (dim.str() != ("dim" + llvm::Twine(i)).str()) {
+      return emitError()
+             << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
+             << dim << " at position " << i;
+    }
+  }
+
+  const auto &bases = linearLayout.getBases();
+  auto nonZero = [](auto val) { return val != 0; };
+  for (const auto &dimBases : llvm::make_second_range(bases)) {
+    if (!llvm::all_of(dimBases, [&](const auto &basis) {
+          return std::count_if(basis.begin(), basis.end(), nonZero) <= 1;
+        })) {
+      return emitError()
+             << "In a shared linear layout, each base must move in at most "
+                "one dimension.";
+    }
+  }
+
+  return success();
+}
+
+// Prints the attribute in the same dictionary form that parse() accepts.
+void SharedLinearEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  // We don't use the default implementation as it's a bit too verbose
+  // This prints in the following format that is shape agnostic, in the sense
+  // that we don't print explicitly the outShape of the LL
+  // We always assume LLs to be surjective
+  // <{offset = [[0, 1], [0, 2], [0, 4], [1, 0]], block = [], alignment = 16}>
+  //
+  // The alignment is emitted inside the dictionary because parse() rejects a
+  // dictionary without an `alignment` entry; omitting it would make the
+  // printed form impossible to parse back.
+  auto ll = getLinearLayout();
+  printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) {
+    return base.first.str() + " = " + "[" +
+           join(base.second, ", ",
+                [](const std::vector<int32_t> &vec) {
+                  return "[" + join(vec, ", ") + "]";
+                }) +
+           "]";
+  }) << ", alignment = "
+          << getAlignment() << "}>";
+}
+
+// Parses `<{offset = [[...], ...], block = [[...], ...], alignment = N}>`.
+// Returns a null attribute (after emitting a diagnostic) on malformed input.
+Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess().failed())
+    return {};
+
+  DictionaryAttr dict;
+  if (parser.parseAttribute(dict).failed())
+    return {};
+
+  if (parser.parseGreater().failed())
+    return {};
+
+  LinearLayout::BasesT bases;
+
+  // Parse the basis names in order (the order is relevant)
+  std::vector<std::string> inDimNames = {"offset", "block"};
+
+  for (const auto &inDimNameStr : inDimNames) {
+    auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
+    Attribute value = dict.get(inDimName);
+
+    // dict.get() yields a null attribute for a missing key, and dyn_cast must
+    // not be applied to a null value, so diagnose the missing entry here.
+    if (!value) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "Expected basis of '")
+          << inDimName.getValue() << "' to be present";
+      return {};
+    }
+
+    // Expecting an array of arrays
+    auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
+    if (!arrayOfArraysAttr) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "Expected array of arrays for basis of '")
+          << inDimName.getValue() << "'";
+      return {};
+    }
+
+    std::vector<std::vector<int32_t>> inDimBases;
+    for (Attribute arrayAttr : arrayOfArraysAttr) {
+      auto intArrayAttr = mlir::dyn_cast<ArrayAttr>(arrayAttr);
+      if (!intArrayAttr) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "Expected array of integers in basis for '")
+            << inDimName.getValue() << "'";
+        return {};
+      }
+      std::vector<int32_t> basis;
+      for (Attribute intAttr : intArrayAttr) {
+        auto intValueAttr = mlir::dyn_cast<IntegerAttr>(intAttr);
+        if (!intValueAttr) {
+          parser.emitError(parser.getCurrentLocation(),
+                           "Expected integer in basis for '")
+              << inDimName.getValue() << "'";
+          return {};
+        }
+        basis.push_back(intValueAttr.getInt());
+      }
+      inDimBases.push_back(std::move(basis));
+    }
+    bases[inDimName] = std::move(inDimBases);
+  }
+
+  // The rank is taken from the first basis vector found in any input dim.
+  size_t rank = 0;
+  for (const auto &basesDim : llvm::make_second_range(bases)) {
+    if (!basesDim.empty()) {
+      rank = basesDim[0].size();
+      break;
+    }
+  }
+
+  // To implement this we'd need to serialise the rank as well.
+  // We can do this if we ever need it
+  if (rank == 0) {
+    parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported");
+    return {};
+  }
+
+  // Generate standard outDimNames (dim0, dim1, ...)
+  SmallVector<StringAttr> outDimNames;
+  for (size_t i = 0; i < rank; ++i) {
+    outDimNames.push_back(
+        StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i)));
+  }
+
+  // Create LinearLayout
+  LinearLayout linearLayout(std::move(bases), std::move(outDimNames));
+
+  // Parse alignment attribute
+  auto alignmentStr = StringAttr::get(parser.getContext(), "alignment");
+  unsigned alignment = 0;
+  auto alignmentAttr = dict.getNamed(alignmentStr);
+  if (!alignmentAttr.has_value()) {
+    // Returning a null attribute without a diagnostic would make the parse
+    // fail with no explanation.
+    parser.emitError(parser.getCurrentLocation(),
+                     "Expected 'alignment' entry in shared linear layout");
+    return {};
+  }
+
+  if (parseUInt(parser, alignmentAttr.value(), alignment, "alignment").failed())
+    return {};
+
+  // Create and return the SharedLinearEncodingAttr
+  return parser.getChecked<SharedLinearEncodingAttr>(
+      parser.getContext(), std::move(linearLayout), alignment);
+}
+
+// Not supported for this encoding: asserts in debug builds and returns 0
+// otherwise.
+unsigned SharedLinearEncodingAttr::getVec() const {
+  assert(false && "getVec is not supported for shared linear layout");
+  return 0;
+}
+
+// Not supported for this encoding: asserts in debug builds and returns 0
+// otherwise.
+unsigned SharedLinearEncodingAttr::getPerPhase() const {
+  assert(false && "getPerPhase is not supported for shared linear layout");
+  return 0;
+}
+
+// Not supported for this encoding: asserts in debug builds and returns 0
+// otherwise.
+unsigned SharedLinearEncodingAttr::getMaxPhase() const {
+  assert(false && "getMaxPhase is not supported for shared linear layout");
+  return 0;
+}
+
+// Returns the tensor dimensions in the order in which they are first moved by
+// the "offset" bases, followed by any remaining dimensions in ascending order.
+SmallVector<unsigned> SharedLinearEncodingAttr::getOrder() const {
+  auto dimName = StringAttr::get(getContext(), "offset");
+  auto ll = getLinearLayout();
+  unsigned rank = getSharedLinearRank(ll);
+
+  llvm::SetVector<unsigned> order;
+  auto nonZero = [](auto val) { return val != 0; };
+  const auto &allBases = ll.getBases();
+  auto offsetIt = allBases.find(dimName);
+  if (offsetIt != allBases.end()) {
+    for (const auto &basis : offsetIt->second) {
+      // Bases can have one or zero non-zero elements
+      // Skip a basis if it's all zeros
+      auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+      if (it != basis.end())
+        order.insert(static_cast<unsigned>(it - basis.begin()));
+    }
+  }
+  // If any dim is missing, we add them in ascending order
+  for (unsigned i = 0; i < rank; ++i)
+    order.insert(i);
+  return SmallVector<unsigned>(order.begin(), order.end());
+}
+
+// Returns 1 for every dimension.
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTAsPerCGA() const {
+  // TODO replace with block analysis?
+  return SmallVector<unsigned>(getSharedLinearRank(getLinearLayout()), 1);
+}
+
+// Returns 1 for every dimension.
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTASplitNum() const {
+  // TODO replace with block analysis?
+  return SmallVector<unsigned>(getSharedLinearRank(getLinearLayout()), 1);
+}
+
+// Returns the identity order [0, 1, ..., rank-1].
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTAOrder() const {
+  // TODO replace with block analysis?
+  SmallVector<unsigned> ctaOrder(getSharedLinearRank(getLinearLayout()));
+  std::iota(ctaOrder.begin(), ctaOrder.end(), 0u);
+  return ctaOrder;
+}
+
+// Always reports false for this encoding.
+bool SharedLinearEncodingAttr::hasLeadingOffset() const {
+  // assert(false);
+  return false;
+}
+
+// Forwards to the generated accessor of the stored alignment parameter.
+int32_t SharedLinearEncodingAttr::getAlignment() const {
+  return this->getAlignment__();
+}
+
+// Not meaningful for a shared layout; reaching this is a programming error.
+SmallVector<unsigned>
+SharedLinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
+                                            Type eltTy) const {
+  llvm_unreachable("getElemsPerThread is not supported for shared layout");
+  return {};
+}
+
+// Not meaningful for a shared layout; reaching this is a programming error.
+unsigned
+SharedLinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
+                                                 Type eltTy) const {
+  llvm_unreachable(
+      "getTotalElemsPerThread is not supported for shared layout");
+  return 0;
+}
+
 //===----------------------------------------------------------------------===//
 // MMA encoding
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index
2d8e10021028..e43210671440 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -890,12 +890,17 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, } else if (auto distributed = dyn_cast(layout)) { result = distributed.toLinearLayout(shape); } else { - auto shared = dyn_cast(layout); - if (shared.hasLeadingOffset()) { - assert(elemBitWidth.has_value()); - result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + if (auto linearEncoding = dyn_cast(layout)) { + result = linearEncoding.getLinearLayout(); } else { - result = sharedToLinearLayoutNoLeadingOffset(shape, shared); + auto shared = dyn_cast(layout); + if (shared.hasLeadingOffset()) { + assert(elemBitWidth.has_value()); + result = + sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + } else { + result = sharedToLinearLayoutNoLeadingOffset(shape, shared); + } } } diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index ccfb8612f243..144d32c71df8 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -294,25 +294,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> -#shared = #ttg.linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = []}> -#trasposed_in_regs = #ttg.linear<{register = [[1, 0], [2, 0], [0, 1], [0, 2], [0, 4]], lane = [[0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0]], warp = [[32, 
0], [64, 0]], block = []}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - tt.func public @custom_shared_layout(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked1>) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> - - %opA = ttg.convert_layout %arg0 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - - %trasposed_in_regs = ttg.convert_layout %arg1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #trasposed_in_regs> - %allocated = ttg.local_alloc %trasposed_in_regs : (tensor<128x64xf16, #trasposed_in_regs>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %opB = ttg.local_load %allocated : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - - %result = tt.dot %opA, %opB, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> - tt.return - } -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index f60521d4269b..9f34cbda49d2 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2265,3 +2265,16 @@ tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #li } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = [], alignment = 16}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: shared_linear_store_load + tt.func 
public @shared_linear_store_load(%arg0: tensor<128x64xf16, #blocked>) { + %0 = ttg.local_alloc %arg0 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %1 = ttg.local_load %0 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #blocked> + tt.return + } +}