diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 73413c11e59e..43cfb75dbba0 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -666,6 +666,34 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding" let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// Shared Linear Encoding +//===----------------------------------------------------------------------===// + +def SharedLinearEncodingAttr : SharedEncoding<"SharedLinearEncoding", "shared_linear_encoding"> { + let mnemonic = "shared_linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = ( + ins + LinearLayoutParam:$linearLayout, + "int32_t":$alignment__ + ); + + let extraClassDeclaration = extraSharedEncodingDeclaration; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{offset = [[1, 0], [2, 0], [4, 0], [8, 0], + // [0, 1], [0, 2], [0, 4], [0, 8]], + // block = [], + // alignment = 16}> + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// // Blocked Layout Encoding diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 1a50155f0ff3..be14af60896b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -244,7 +244,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout, // We propose case 2 (see comments below), which provides a more general // solution for all swizzled shared memory scenarios, including the edge case // mentioned above. 
- if (sharedEnc && + if (mlir::isa(sharedTy.getEncoding()) && + sharedEnc && isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, {{kRegister, regId}, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index b3051fe7f395..04792b5b0cb1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1189,18 +1189,6 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { // Linear Encoding //===----------------------------------------------------------------------===// -namespace { - -bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) { - return (ll.getInDimNames().begin()->str() == "offset"); -} - -bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) { - return (ll.getInDimNames().begin()->str() == "register"); -} - -} // namespace - // FIXME Can we take the LinearLayout by const&? LogicalResult LinearEncodingAttr::verify(function_ref emitError, @@ -1210,20 +1198,12 @@ LinearEncodingAttr::verify(function_ref emitError, // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], // warp = [[16, 0], [32, 0]], // block = []}> - // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, - // 0]], - // block = []}> - // The input dims must be {register, lane, warp, block} or {offset, block} + // The input dims must be {register, lane, warp, block} // The output dims of the linear layout should be dim0..dim[rank-1] - SmallVector expectedInDims; + SmallVector expectedInDims = {"register", "lane", "warp", + "block"}; // check if this is a memory of distributed layout - if (isMemoryEncoding(linearLayout)) { - expectedInDims = {"offset", "block"}; - } else { - assert(isDistributedEncoding(linearLayout)); - expectedInDims = {"register", "lane", "warp", "block"}; - } for (const auto &[i, dims] : llvm::enumerate( llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { @@ -1292,20 
+1272,12 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { LinearLayout::BasesT bases; // Parse the basis names in order (the order is relevant) - std::vector inDimNames = {"offset", "register", "lane", "warp", - "block"}; + std::vector inDimNames = {"register", "lane", "warp", "block"}; for (const auto &inDimNameStr : inDimNames) { auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); Attribute value = dict.get(inDimName); - // there are two cases: - // memory layout: {offset, block} - // and distributed layout: {register, lane, warp, block} - // this check skips redundant fields - if (!value) - continue; - // Expecting an array of arrays auto arrayOfArraysAttr = mlir::dyn_cast(value); if (!arrayOfArraysAttr) { @@ -1513,12 +1485,7 @@ SmallVector LinearEncodingAttr::getOrder() const { // This order is as good as any really std::iota(order.rbegin(), order.rend(), 0); - if (isMemoryEncoding(linearLayout)) { - return orderPerDim(StringAttr::get(getContext(), "offset"), order); - } else { - assert(isDistributedEncoding(linearLayout)); - return orderPerDim(StringAttr::get(getContext(), "register"), order); - } + return orderPerDim(StringAttr::get(getContext(), "register"), order); } LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { @@ -1526,11 +1493,6 @@ LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); llvm::SmallDenseMap namedShape; llvm::SmallVector permutedDims; - if (isMemoryEncoding(ll)) { - // memory encoding should be 1-to-1 compatible with tensor size - // TODO: check layout size equal to shape - return ll; - } for (auto dim : getRepOrder()) { permutedDims.push_back(canonicalDims[dim]); namedShape[canonicalDims[dim]] = shape[dim]; @@ -1582,6 +1544,246 @@ unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } 
+//===----------------------------------------------------------------------===// +// Shared Linear Encoding +//===----------------------------------------------------------------------===// + +LogicalResult +SharedLinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout, int32_t alignment) { + // Examples of LinearEncodingAttr + // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]], block = []}> + // The input dims must be {offset, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + SmallVector expectedInDims = {"offset", "block"}; + + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. 
Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + return success(); +} + +void SharedLinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + // We don't use the default implementation as it's a bit too verbose + // This prints in the following format that is shape agnostic, in the sense + // that we don't print explicitly the outShape of the LL + // We always assume LLs to be surjective + // <{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4]], + // block = []}> + // NOTE(review): parse() below requires an "alignment" entry in the dict, + // but this printer never emits it, so printed attrs may not round-trip. + auto ll = getLinearLayout(); + printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }) << "}>"; +} + +Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + std::vector inDimNames = {"offset", "block"}; + + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of 
'") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + // To implement this we'd need to serialise the rank as well. + // We can do this if we ever need it + if (rank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + // Generate standard outDimNames (dim0, dim1, ...) 
+ SmallVector outDimNames; + for (int i = 0; i < rank; ++i) { + outDimNames.push_back( + StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); + } + + // Create LinearLayout + LinearLayout linearLayout(std::move(bases), std::move(outDimNames)); + + // Parse alignment attribute + auto alignmentStr = StringAttr::get(parser.getContext(), "alignment"); + unsigned alignment = 0; + auto alignmentAttr = dict.getNamed(alignmentStr); + if (!alignmentAttr.has_value()) { + // TODO: report error + return {}; + } + + if (parseUInt(parser, alignmentAttr.value(), alignment, "alignment").failed()) + return {}; + + // Create and return the LinearEncodingAttr + return parser.getChecked( + parser.getContext(), std::move(linearLayout), alignment); +} + +unsigned SharedLinearEncodingAttr::getVec() const { + assert(false); + return 0; +}; +unsigned SharedLinearEncodingAttr::getPerPhase() const { + assert(false); + return 0; +}; +unsigned SharedLinearEncodingAttr::getMaxPhase() const { + assert(false); + return 0; +}; +SmallVector SharedLinearEncodingAttr::getOrder() const { + auto dimName = StringAttr::get(getContext(), "offset"); + auto ll = getLinearLayout(); + const auto &bases = ll.getBases().find(dimName)->second; + + unsigned rank = bases.begin()->size(); + auto defaultOrder = llvm::iota_range(0u, rank, false); + llvm::SetVector order; + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &basis : bases) { + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. 
warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + if (it != basis.end()) { + auto i = it - basis.begin(); + order.insert(i); + } + } + // If any dim is missing, we add them in the defaultOrder + for (auto i : defaultOrder) { + order.insert(i); + } + return SmallVector(order.begin(), order.end()); +}; +SmallVector SharedLinearEncodingAttr::getCTAsPerCGA() const { + // TODO replace with block analysis? + auto rank = getLinearLayout() + .getBasis(StringAttr::get(getContext(), "offset"), 0) + .size(); + return SmallVector(rank, 1); +}; +SmallVector SharedLinearEncodingAttr::getCTASplitNum() const { + // TODO replace with block analysis? + auto rank = getLinearLayout() + .getBasis(StringAttr::get(getContext(), "offset"), 0) + .size(); + return SmallVector(rank, 1); +}; +SmallVector SharedLinearEncodingAttr::getCTAOrder() const { + // TODO replace with block analysis? + auto rank = getLinearLayout() + .getBasis(StringAttr::get(getContext(), "offset"), 0) + .size(); + return llvm::to_vector(llvm::iota_range(0, rank, false)); +}; +bool SharedLinearEncodingAttr::hasLeadingOffset() const { + // assert(false); + return false; +}; +int32_t SharedLinearEncodingAttr::getAlignment() const { + return this->getAlignment__(); +}; +SmallVector +SharedLinearEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return {}; +} +unsigned +SharedLinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return 0; +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 
2d8e10021028..e43210671440 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -890,12 +890,17 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, } else if (auto distributed = dyn_cast(layout)) { result = distributed.toLinearLayout(shape); } else { - auto shared = dyn_cast(layout); - if (shared.hasLeadingOffset()) { - assert(elemBitWidth.has_value()); - result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + if (auto linearEncoding = dyn_cast(layout)) { + result = linearEncoding.getLinearLayout(); } else { - result = sharedToLinearLayoutNoLeadingOffset(shape, shared); + auto shared = dyn_cast(layout); + if (shared.hasLeadingOffset()) { + assert(elemBitWidth.has_value()); + result = + sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + } else { + result = sharedToLinearLayoutNoLeadingOffset(shape, shared); + } } } diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index ccfb8612f243..144d32c71df8 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -294,25 +294,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> -#shared = #ttg.linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = []}> -#trasposed_in_regs = #ttg.linear<{register = [[1, 0], [2, 0], [0, 1], [0, 2], [0, 4]], lane = [[0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0]], warp = [[32, 
0], [64, 0]], block = []}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - tt.func public @custom_shared_layout(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked1>) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> - - %opA = ttg.convert_layout %arg0 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - - %trasposed_in_regs = ttg.convert_layout %arg1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #trasposed_in_regs> - %allocated = ttg.local_alloc %trasposed_in_regs : (tensor<128x64xf16, #trasposed_in_regs>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %opB = ttg.local_load %allocated : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - - %result = tt.dot %opA, %opB, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> - tt.return - } -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index f60521d4269b..9f34cbda49d2 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2265,3 +2265,16 @@ tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #li } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = [], alignment = 16}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: shared_linear_store_load + tt.func 
public @shared_linear_store_load(%arg0: tensor<128x64xf16, #blocked>) { + %0 = ttg.local_alloc %arg0 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %1 = ttg.local_load %0 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #blocked> + tt.return + } +}