From a91f8e9d24a3a549d4005571ae3c2ce94c85745e Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Mon, 27 Jan 2025 22:23:05 +0000
Subject: [PATCH] Support shared encoding defined with linear layout

This PR enables basic support for a linear layout used as a shared
encoding: a #ttg.linear attribute whose input dimensions are
{offset, block} can now describe shared memory, alongside the existing
{register, lane, warp, block} form used for distributed tensors.
---
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp             |  3 -
 lib/Conversion/TritonGPUToLLVM/Utility.cpp         |  8 ++-
 lib/Dialect/TritonGPU/IR/Dialect.cpp               | 59 ++++++++++++++++---
 .../TritonGPU/IR/LinearLayoutConversions.cpp       |  4 +-
 test/Conversion/amd/tritongpu_to_llvm.mlir         | 22 +++++++
 .../SharedToDotOperandMFMA.cpp                     |  6 +-
 6 files changed, 87 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 47e2cea970b7..dd7a747c44e8 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -22,7 +22,6 @@ void lowerDistributedToShared(
     std::pair<size_t, Type> *const llvmOpCount = nullptr) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
-  auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
   auto elemTy = typeConverter->convertType(srcTy.getElementType());
 
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
@@ -76,8 +75,6 @@ struct LocalAllocOpConversion
         LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
     auto resultTy = cast<MemDescType>(op.getType());
     auto typeConverter = getTypeConverter();
-    auto sharedLayout =
-        cast<SharedEncodingAttr>(resultTy.getEncoding());
     auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
 
     auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 4d25656199dd..037852aca6f6 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -240,7 +240,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
   // We propose case 2 (see comments below), which provides a more general
   // solution for all swizzled shared memory scenarios, including the edge case
   // mentioned above.
-  if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
+  if (sharedEnc &&
+      isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
     smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
                                    {{kRegister, regId},
                                     {kLane, laneId},
@@ -249,8 +250,9 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
             .second;
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
-    assert(!sharedEnc.getHasLeadingOffset() &&
-           "Leading offsets are not supported for sliced tensors");
+    assert((!sharedEnc ||
+            !sharedEnc.getHasLeadingOffset()) &&
+           "Leading offsets are not supported for sliced tensors");
     // We define both tensor offsets and shared memory offsets:
     //
     // - Tensor offsets: Relative offsets within a given tensor.
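A note on the Utility.cpp change above: with a linear shared encoding, sharedEnc is null and the shared memory offset is computed entirely by applyLinearLayout. A linear layout is linear over F2: each set bit of the input coordinate selects a basis vector, and the selected bases are XOR-ed together. The standalone C++ sketch below illustrates that arithmetic (applyBases and its signature are illustrative, not Triton's LinearLayout API), using the "offset" bases of the #shared encoding from the test later in this patch.

#include <array>
#include <cstdio>
#include <vector>

// Illustrative helper: apply a rank-2 F2-linear layout. bases[k] is the
// output coordinate contributed by bit k of the input index.
static std::array<int, 2>
applyBases(const std::vector<std::array<int, 2>> &bases, unsigned in) {
  std::array<int, 2> out = {0, 0};
  for (unsigned bit = 0; bit < bases.size(); ++bit) {
    if ((in >> bit) & 1) {
      out[0] ^= bases[bit][0]; // XOR, not addition: the map is linear over F2
      out[1] ^= bases[bit][1];
    }
  }
  return out;
}

int main() {
  // The "offset" bases of the #shared encoding in the test below: the low
  // seven offset bits walk dim0, the remaining six bits walk dim1.
  std::vector<std::array<int, 2>> offsetBases = {
      {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {32, 0}, {64, 0},
      {0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}};
  // Offset 129 = 0b10000001 selects bases {1, 0} and {0, 1} -> element (1, 1).
  auto coord = applyBases(offsetBases, 129);
  std::printf("(%d, %d)\n", coord[0], coord[1]);
  return 0;
}
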
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index cddcf142b2d5..c1f78ccd9480 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1139,20 +1139,46 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "}>";
 }
 
+//===----------------------------------------------------------------------===//
+// Linear Encoding
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) {
+  return (ll.getInDimNames().begin()->str() == "offset");
+}
+
+bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) {
+  return (ll.getInDimNames().begin()->str() == "register");
+}
+
+} // namespace
+
 // FIXME Can we take the LinearLayout by const&?
 LogicalResult
 LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                            LinearLayout linearLayout) {
-  // Example of LinearEncodingAttr
+  // Examples of LinearEncodingAttr
   // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
   //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
   //   warp = [[16, 0], [32, 0]],
   //   block = []}>
-  // The input dims must be {register, lane, warp, block}
+  // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4,
+  // 0]],
+  //   block = []}>
+  // The input dims must be {register, lane, warp, block} or {offset, block}
   // The output dims of the linear layout should be dim0..dim[rank-1]
-  static const auto expectedInDims =
-      SmallVector<std::string>({"register", "lane", "warp", "block"});
+  SmallVector<std::string> expectedInDims;
+  // Check whether this is a memory or a distributed layout
+  if (isMemoryEncoding(linearLayout)) {
+    expectedInDims = {"offset", "block"};
+  } else {
+    assert(isDistributedEncoding(linearLayout));
+    expectedInDims = {"register", "lane", "warp", "block"};
+  }
+
   for (const auto &[i, dims] : llvm::enumerate(
            llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
     const auto &[dim, expectedDimStr] = dims;
@@ -1220,12 +1246,20 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) {
   LinearLayout::BasesT bases;
 
   // Parse the basis names in order (the order is relevant)
-  std::vector<std::string> inDimNames = {"register", "lane", "warp", "block"};
+  std::vector<std::string> inDimNames = {"offset", "register", "lane", "warp",
+                                         "block"};
 
   for (const auto &inDimNameStr : inDimNames) {
     auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
     Attribute value = dict.get(inDimName);
+    // There are two cases:
+    //   memory layout: {offset, block}
+    //   distributed layout: {register, lane, warp, block}
+    // Skip the in-dim names that are absent for this kind of layout.
+    if (!value)
+      continue;
+
     // Expecting an array of arrays
     auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
     if (!arrayOfArraysAttr) {
@@ -1425,14 +1459,20 @@ SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
 }
 
 SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
-  auto rank = getLinearLayout().getNumOutDims();
+  auto linearLayout = getLinearLayout();
+  auto rank = linearLayout.getNumOutDims();
   SmallVector<unsigned> order(rank);
   // Choose [rank-1, rank-2, ... 0] as the default order in case
   // there are dims that do not move in the register
   // This order is as good as any really
   std::iota(order.rbegin(), order.rend(), 0);
-  return orderPerDim(StringAttr::get(getContext(), "register"), order);
+  if (isMemoryEncoding(linearLayout)) {
+    return orderPerDim(StringAttr::get(getContext(), "offset"), order);
+  } else {
+    assert(isDistributedEncoding(linearLayout));
+    return orderPerDim(StringAttr::get(getContext(), "register"), order);
+  }
 }
 
 LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto ll = getLinearLayout();
@@ -1440,6 +1480,11 @@ LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const
   auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
   llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
   llvm::SmallVector<StringAttr> permutedDims;
+  if (isMemoryEncoding(ll)) {
+    // A memory encoding is expected to match the tensor shape 1-to-1.
+    // TODO: check that the layout size equals the shape
+    return ll;
+  }
   for (auto dim : getRepOrder()) {
     permutedDims.push_back(canonicalDims[dim]);
     namedShape[canonicalDims[dim]] = shape[dim];
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 3abd7a114351..1a53e8e9d08d 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -879,7 +879,9 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   // Layouts are distributed or shared in triton core
   // To add a new layout add an else-if clause
   LinearLayout result = LinearLayout::empty();
-  if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
+  if (auto linearLayout = dyn_cast<LinearEncodingAttr>(layout)) {
+    result = linearLayout.toLinearLayout(shape);
+  } else if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
     result = distributed.toLinearLayout(shape);
   } else {
     auto shared = dyn_cast<SharedEncodingAttr>(layout);
diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir
index 144d32c71df8..ccfb8612f243 100644
--- a/test/Conversion/amd/tritongpu_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -294,3 +294,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = []}>
+#transposed_in_regs = #ttg.linear<{register = [[1, 0], [2, 0], [0, 1], [0, 2], [0, 4]], lane = [[0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @custom_shared_layout(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked1>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+
+    %opA = ttg.convert_layout %arg0 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+
+    %transposed_in_regs = ttg.convert_layout %arg1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #transposed_in_regs>
+    %allocated = ttg.local_alloc %transposed_in_regs : (tensor<128x64xf16, #transposed_in_regs>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>
+    %opB = ttg.local_load %allocated : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+
+    %result = tt.dot %opA, %opB, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index e4a683d870f8..61b90e48a43a 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -210,7 +210,11 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
          (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
 
-  auto sharedLayout = cast<SharedEncodingAttr>(aTensorTy.getEncoding());
+  auto sharedLayout = dyn_cast<SharedEncodingAttr>(aTensorTy.getEncoding());
+  if (!sharedLayout) {
+    // This converter supports only the swizzled shared encoding.
+    return Value();
+  }
   auto order = sharedLayout.getOrder();
   assert((rank == 2 || order[2] == 0) &&
          "expect batch to be the slowest dimension");
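A note on the dispatch introduced in Dialect.cpp: whether a #ttg.linear attribute denotes a memory (shared) encoding or a distributed one is decided solely by the name of its first input dimension, and parse() skips in-dim names absent from the attribute dictionary, so existing distributed linear encodings round-trip unchanged. The standalone C++ sketch below mirrors that classification rule (std::string stands in for MLIR attributes; expectedInDimsFor is an illustrative name, not Triton's API).

#include <cassert>
#include <string>
#include <vector>

// Illustrative mirror of isMemoryEncoding/isDistributedEncoding: the first
// input dimension name classifies the linear encoding.
static std::vector<std::string>
expectedInDimsFor(const std::vector<std::string> &inDims) {
  assert(!inDims.empty());
  if (inDims.front() == "offset")
    return {"offset", "block"}; // memory (shared) encoding
  assert(inDims.front() == "register" && "unknown linear encoding kind");
  return {"register", "lane", "warp", "block"}; // distributed encoding
}

int main() {
  assert(expectedInDimsFor({"offset", "block"}).size() == 2);
  assert(expectedInDimsFor({"register", "lane", "warp", "block"}).size() == 4);
  return 0;
}

Keying on the leading in-dim name keeps the textual attribute format backward compatible and avoids adding an explicit discriminator field to the attribute.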