Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support shared encoding defined with linear layout #5720

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
2 changes: 1 addition & 1 deletion bin/triton-tensor-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
// toLinearLayout interface.
mlir::Attribute layout = tensorType.getEncoding();
if (isa<mlir::triton::gpu::DistributedEncodingTrait,
mlir::triton::gpu::SharedEncodingAttr>(layout)) {
mlir::triton::gpu::SharedEncodingTrait>(layout)) {
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
return success();
}
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ inline Value packLLVector(Location loc, ValueRange vals,
inline bool
isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
ArrayRef<int64_t> allocShape,
triton::gpu::SharedEncodingAttr sharedEnc) {
triton::gpu::SharedEncodingTrait sharedEnc) {
auto rank = shape.size();
return /*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
/*swizzling but same shape*/ shape == allocShape ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "triton/Tools/LinearLayout.h"

namespace mlir::triton::gpu {
class SharedEncodingAttr;
class SharedEncodingTrait;

// - BlockedEncodingAttrs have the following input dimensions.
//
Expand Down Expand Up @@ -48,7 +48,7 @@ LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
// If `disableSwizzle` is set, then the resulting layout does not include
// swizzling.
LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
SharedEncodingAttr shared,
SharedEncodingTrait shared,
int32_t elemBitWidth,
bool disableSwizzle = false);

Expand Down
106 changes: 95 additions & 11 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,8 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//

def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> {
let mnemonic = "shared";
def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";

let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
Expand Down Expand Up @@ -235,16 +234,75 @@ a matrix of size 16x128 and data type I8 is stored in the shared memory with
compared to 1*64 when the hasLeadingOffset is false.
}];

let methods = [
InterfaceMethod<"Get 'vec', one of properties of swizzling.",
"unsigned",
"getVec">,
InterfaceMethod<"Get 'per phase', one of properties of swizzling.",
"unsigned",
"getPerPhase">,
InterfaceMethod<"Get 'max phase', one of properties of swizzling.",
"unsigned",
"getMaxPhase">,
InterfaceMethod<"Get the order of this SMEM encoding.",
"SmallVector<unsigned>",
"getOrder">,
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
"getCTAsPerCGA">,
InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
"SmallVector<unsigned>",
"getCTASplitNum">,
InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
"SmallVector<unsigned>",
"getCTAOrder">,
InterfaceMethod<"Get the size of the address alignment in bytes.",
"int32_t",
"getAlignment">,
InterfaceMethod<"True if the SMEM layout is in the core-matrices format.",
"bool",
"hasLeadingOffset">
];
}

// Base TableGen class for concrete shared-memory (SMEM) encoding attributes.
// Deriving from this attaches the SharedEncodingTrait attribute interface to
// the generated attribute, so code can dispatch on the trait instead of on a
// single concrete attribute class.
class SharedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
                     Dialect dialect = TritonGPU_Dialect>
  // Prepend SharedEncodingTrait to whatever extra traits the derived def asks for.
  : TritonGPU_Attr<name, attrMnemonic, !listconcat([SharedEncodingTrait], traits), dialect> {

  let description = [{
The interface which can be inherited by various representations of an SMEM encoding.
  }];

  // C++ declarations that each derived attribute splices into its generated
  // class (via `let extraClassDeclaration = extraSharedEncodingDeclaration;`).
  // They declare the SharedEncodingTrait interface methods; bodies are
  // expected to be defined out of line by each concrete attribute —
  // TODO(review): confirm where the definitions live.
  code extraSharedEncodingDeclaration = extraBaseClassDeclaration # [{
    unsigned getVec() const;        // 'vec', one of the swizzling properties
    unsigned getPerPhase() const;   // 'per phase', one of the swizzling properties
    unsigned getMaxPhase() const;   // 'max phase'; 1 means no swizzling
    SmallVector<unsigned> getOrder() const;        // axis order of the encoding
    SmallVector<unsigned> getCTAsPerCGA() const;   // shape of CTAs per CGA
    SmallVector<unsigned> getCTASplitNum() const;  // each CTA handles 1/CTASplitNum of the tensor
    SmallVector<unsigned> getCTAOrder() const;     // CTA order, fastest-changing axis first
    bool hasLeadingOffset() const;  // true if laid out in the core-matrices format
    int32_t getAlignment() const;   // address alignment in bytes
  }];
}

def SharedEncodingAttr : SharedEncoding<"SharedEncoding", "shared_encoding"> {
let mnemonic = "shared";

let description = [{
The concrete attribute class for an SMEM encoding, used by multiple GPU backends.
}];

// swizzle info: vec, perPhase, maxPhase
// order: the fastest-changing axis first
let parameters = (
ins
"unsigned":$vec,
"unsigned":$perPhase,
"unsigned":$maxPhase,
ArrayRefParameter<"unsigned">:$order,
"unsigned":$vec__,
"unsigned":$perPhase__,
"unsigned":$maxPhase__,
ArrayRefParameter<"unsigned">:$order__,
"CTALayoutAttr":$CTALayout,
"bool":$hasLeadingOffset
"bool":$hasLeadingOffset__
);

let builders = [
Expand Down Expand Up @@ -432,9 +490,7 @@ compared to 1*64 when the hasLeadingOffset is false.
}]>
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
int32_t getAlignment() const;
}];
let extraClassDeclaration = extraSharedEncodingDeclaration;
let hasCustomAssemblyFormat = 1;
}

Expand Down Expand Up @@ -610,6 +666,34 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Shared Linear Encoding
//===----------------------------------------------------------------------===//

// An SMEM encoding expressed directly as a linear layout, as an alternative
// to the swizzle-parameter form of SharedEncodingAttr.
def SharedLinearEncodingAttr : SharedEncoding<"SharedLinearEncoding", "shared_linear_encoding"> {
  let mnemonic = "shared_linear";

  let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
  }];

  // Parameters:
  //  - linearLayout: the linear layout defining this shared-memory encoding.
  //  - alignment__: address alignment in bytes. The trailing "__" presumably
  //    keeps the TableGen-generated accessor from colliding with the
  //    getAlignment() declared via extraSharedEncodingDeclaration —
  //    TODO(review): confirm.
  let parameters = (
    ins
    LinearLayoutParam:$linearLayout,
    "int32_t":$alignment__
  );

  // Splice in the SharedEncodingTrait method declarations (getVec,
  // getMaxPhase, getOrder, ...).
  let extraClassDeclaration = extraSharedEncodingDeclaration;

  // Generate a verify() declaration to check the attribute's invariants.
  let genVerifyDecl = 1;
  // Example of assembly format:
  //  <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
  //    lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
  //    warp = [[16, 0], [32, 0]],
  //    block = []}>
  let hasCustomAssemblyFormat = 1;
}


//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ bool isPureUnaryInlineAsm(Operation *op);
// read the compute capability from the module attributes
int getNVIDIAComputeCapability(Operation *module);

std::optional<mlir::triton::gpu::SharedEncodingAttr>
std::optional<mlir::triton::gpu::SharedEncodingTrait>
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

enum class MMALoadType {
Expand Down
4 changes: 2 additions & 2 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::MemDescType;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingTrait;

Value getStructFromValueTable(ArrayRef<Value> vals,
ConversionPatternRewriter &rewriter, Location loc,
Expand All @@ -32,11 +32,13 @@ Value getStructFromValueTable(ArrayRef<Value> vals,
return packLLElements(loc, typeConverter, elems, rewriter, structTy);
}

bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; }
// Whether the shared-memory layout applies swizzling: a max-phase of 1 is the
// identity (no swizzle); any other value means the layout is swizzled.
bool isSwizzled(SharedEncodingTrait layout) {
  const bool identityPhase = (layout.getMaxPhase() == 1);
  return !identityPhase;
}

SmallVector<Value> swizzleIndices(ConversionPatternRewriter &rewriter,
Location loc, SmallVector<Value> rawIndices,
SharedEncodingAttr layout) {
SharedEncodingTrait layout) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
const auto &order = layout.getOrder();
auto rank = order.size();
Expand Down Expand Up @@ -154,7 +156,7 @@ Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc,
Value bTileOffset, Value nonKTileOffset,
unsigned shapePerCTABTile,
unsigned shapePerCTANonKTile,
SharedEncodingAttr sharedLayout,
SharedEncodingTrait sharedLayout,
ArrayRef<int64_t> opTensorShape,
ArrayRef<Value> strides) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Expand Down Expand Up @@ -229,7 +231,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
dim.nonK = dotOpNo == 0 ? 1 : 2;
auto opTensorTy = cast<MemDescType>(srcVal.getType());
auto opTensorShape = expandMatrixShapeWithBatch(opTensorTy.getShape());
auto sharedLayout = cast<SharedEncodingAttr>(opTensorTy.getEncoding());
auto sharedLayout = cast<SharedEncodingTrait>(opTensorTy.getEncoding());

auto opOrder = expandMatrixOrderWithBatch(dLayout.getOrder());

Expand Down
2 changes: 0 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ struct LocalAllocOpConversion
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
auto resultTy = cast<MemDescType>(op.getType());
auto typeConverter = getTypeConverter();
auto sharedLayout =
cast<triton::gpu::SharedEncodingAttr>(resultTy.getEncoding());

auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ using namespace mlir::triton;

using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::MemDescType;
using ::mlir::triton::gpu::SharedEncodingAttr;

TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &options,
Expand Down
8 changes: 5 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
auto allocShape = sharedTy.getAllocShape();
auto rank = shape.size();
auto sharedEnc =
dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
dyn_cast<triton::gpu::SharedEncodingTrait>(sharedTy.getEncoding());

auto smemBase = smemObj.getBase();
auto smemOffsets = smemObj.getOffsets();
Expand Down Expand Up @@ -244,7 +244,9 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
// We propose case 2 (see comments below), which provides a more general
// solution for all swizzled shared memory scenarios, including the edge case
// mentioned above.
if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
if (mlir::isa<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding()) &&
sharedEnc &&
isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
{{kRegister, regId},
{kLane, laneId},
Expand All @@ -253,7 +255,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
.second;
} else { // Case 2 -> rank-reduced swizzling
assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
assert(!sharedEnc.getHasLeadingOffset() &&
assert(!sharedEnc.hasLeadingOffset() &&
"Leading offsets are not supported for sliced tensors");
// We define both tensor offsets and shared memory offsets:
//
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ struct MemDescTransOpConversion
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = cast<TensorOrMemDesc>(op.getType());
auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding());
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
Expand Down
Loading