Support shared encoding defined with linear layout
This PR enables basic support for linear layouts used as shared encodings.
binarman committed Jan 27, 2025
1 parent 98b40d5 commit a91f8e9
Showing 6 changed files with 87 additions and 15 deletions.
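For context, a shared-memory encoding written as a linear layout maps the input dims {offset, block} to tensor coordinates, whereas a distributed encoding maps {register, lane, warp, block}; the diffs below teach verify/parse/getOrder to accept both kinds. A minimal sketch of the memory kind, assuming the initializer-list LinearLayout constructor from triton/Tools/LinearLayout.h (the helper name buildSharedLL is hypothetical):

#include "mlir/IR/BuiltinAttributes.h"
#include "triton/Tools/LinearLayout.h"

using namespace mlir;
using triton::LinearLayout;

// Hypothetical helper: a 16x16 row-major shared-memory view. Offset bits
// 0..3 step along dim1 (contiguous elements); bits 4..7 step along dim0.
LinearLayout buildSharedLL(MLIRContext *ctx) {
  auto S = [ctx](StringRef s) { return StringAttr::get(ctx, s); };
  return LinearLayout(
      {{S("offset"),
        {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}},
       {S("block"), {}}},
      {S("dim0"), S("dim1")});
}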
3 changes: 0 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -22,7 +22,6 @@ void lowerDistributedToShared(
     std::pair<size_t, Type> *const llvmOpCount = nullptr) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
-  auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
   auto elemTy = typeConverter->convertType(srcTy.getElementType());

   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
@@ -76,8 +75,6 @@ struct LocalAllocOpConversion
         LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
     auto resultTy = cast<MemDescType>(op.getType());
     auto typeConverter = getTypeConverter();
-    auto sharedLayout =
-        cast<triton::gpu::SharedEncodingAttr>(resultTy.getEncoding());

     auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
     auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
8 changes: 5 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -240,7 +240,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
   // We propose case 2 (see comments below), which provides a more general
   // solution for all swizzled shared memory scenarios, including the edge case
   // mentioned above.
-  if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
+  if (sharedEnc &&
+      isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
     smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
                                    {{kRegister, regId},
                                     {kLane, laneId},
@@ -249,8 +250,9 @@
                      .second;
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
-    assert(!sharedEnc.getHasLeadingOffset() &&
-           "Leading offsets are not supported for sliced tensors");
+    assert(!sharedEnc ||
+           !sharedEnc.getHasLeadingOffset() &&
+               "Leading offsets are not supported for sliced tensors");
     // We define both tensor offsets and shared memory offsets:
     //
     // - Tensor offsets: Relative offsets within a given tensor.
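The new sharedEnc guard exists because sharedEnc can now be null: a linear shared encoding carries no SharedEncodingAttr. A sketch of the convention this function assumes (variable names other than sharedEnc are hypothetical):

// sharedEnc presumably comes from a dyn_cast on the memdesc encoding, so a
// #ttg.linear shared encoding yields a null attribute here.
auto sharedEnc =
    dyn_cast<triton::gpu::SharedEncodingAttr>(memDescTy.getEncoding());
if (sharedEnc && isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) {
  // Case 1: fast path for a swizzled SharedEncodingAttr.
} else {
  // Case 2: generic rank-reduced path; also taken for linear encodings,
  // which never carry a leading offset, so the relaxed assert still holds.
}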
59 changes: 52 additions & 7 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1139,20 +1139,46 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "}>";
 }

+//===----------------------------------------------------------------------===//
+// Linear Encoding
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) {
+  return (ll.getInDimNames().begin()->str() == "offset");
+}
+
+bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) {
+  return (ll.getInDimNames().begin()->str() == "register");
+}
+
+} // namespace
+
 // FIXME Can we take the LinearLayout by const&?
 LogicalResult
 LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                            LinearLayout linearLayout) {
-  // Example of LinearEncodingAttr
+  // Examples of LinearEncodingAttr
   // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
   //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
   //   warp = [[16, 0], [32, 0]],
   //   block = []}>
-  // The input dims must be {register, lane, warp, block}
+  // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0],
+  //             [4, 0]],
+  //   block = []}>
+  // The input dims must be {register, lane, warp, block} or {offset, block}
   // The output dims of the linear layout should be dim0..dim[rank-1]

-  static const auto expectedInDims =
-      SmallVector<std::string>({"register", "lane", "warp", "block"});
+  SmallVector<std::string> expectedInDims;
+  // Check whether this is a memory or a distributed layout
+  if (isMemoryEncoding(linearLayout)) {
+    expectedInDims = {"offset", "block"};
+  } else {
+    assert(isDistributedEncoding(linearLayout));
+    expectedInDims = {"register", "lane", "warp", "block"};
+  }

   for (const auto &[i, dims] : llvm::enumerate(
            llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
     const auto &[dim, expectedDimStr] = dims;
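A hedged sketch of the two kinds that verify() now accepts, with basis values modeled on the comments above (assumes an MLIRContext *ctx in scope and the initializer-list LinearLayout constructor from triton/Tools/LinearLayout.h):

auto S = [ctx](StringRef s) { return StringAttr::get(ctx, s); };
// Distributed kind: {register, lane, warp, block} -> (dim0, dim1).
LinearLayout dist({{S("register"), {{0, 1}}},
                   {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
                   {S("warp"), {{8, 0}, {16, 0}}},
                   {S("block"), {}}},
                  {S("dim0"), S("dim1")});
// Memory kind: {offset, block} -> (dim0, dim1).
LinearLayout mem({{S("offset"), {{0, 1}, {0, 2}, {1, 0}, {2, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")});
// The first input-dim name is what discriminates the two kinds.
assert(isDistributedEncoding(dist) && isMemoryEncoding(mem));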
@@ -1220,12 +1246,20 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) {
   LinearLayout::BasesT bases;

   // Parse the basis names in order (the order is relevant)
-  std::vector<std::string> inDimNames = {"register", "lane", "warp", "block"};
+  std::vector<std::string> inDimNames = {"offset", "register", "lane", "warp",
+                                         "block"};

   for (const auto &inDimNameStr : inDimNames) {
     auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
     Attribute value = dict.get(inDimName);

+    // There are two kinds of layout:
+    //   memory layout:      {offset, block}
+    //   distributed layout: {register, lane, warp, block}
+    // This check skips the in-dims that the given kind does not use.
+    if (!value)
+      continue;
+
     // Expecting an array of arrays
     auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
     if (!arrayOfArraysAttr) {
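For reference, textual forms of the two kinds that this parser now round-trips; a sketch where the basis values are illustrative and mlir::parseAttribute from mlir/AsmParser/AsmParser.h is the assumed entry point:

// Distributed kind: all four of register/lane/warp/block are present.
const char *distForm =
    "#ttg.linear<{register = [[0, 1]], "
    "lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], "
    "warp = [[8, 0], [16, 0]], block = []}>";
// Memory kind: only offset/block are present; the loop above simply skips
// the absent register, lane, and warp keys.
const char *memForm =
    "#ttg.linear<{offset = [[0, 1], [0, 2], [1, 0], [2, 0]], block = []}>";
// e.g. Attribute attr = mlir::parseAttribute(memForm, ctx);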
@@ -1425,21 +1459,32 @@ SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
 }

 SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
-  auto rank = getLinearLayout().getNumOutDims();
+  auto linearLayout = getLinearLayout();
+  auto rank = linearLayout.getNumOutDims();
   SmallVector<unsigned> order(rank);
   // Choose [rank-1, rank-2, ... 0] as the default order in case
   // there are dims that do not move in the register
   // This order is as good as any really
   std::iota(order.rbegin(), order.rend(), 0);

-  return orderPerDim(StringAttr::get(getContext(), "register"), order);
+  if (isMemoryEncoding(linearLayout)) {
+    return orderPerDim(StringAttr::get(getContext(), "offset"), order);
+  } else {
+    assert(isDistributedEncoding(linearLayout));
+    return orderPerDim(StringAttr::get(getContext(), "register"), order);
+  }
 }

 LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto ll = getLinearLayout();
   auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
   llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
   llvm::SmallVector<StringAttr> permutedDims;
+  if (isMemoryEncoding(ll)) {
+    // A memory encoding must match the tensor shape 1-to-1
+    // TODO: check that the layout size equals the shape
+    return ll;
+  }
   for (auto dim : getRepOrder()) {
     permutedDims.push_back(canonicalDims[dim]);
     namedShape[canonicalDims[dim]] = shape[dim];
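A worked example of the new getOrder() and toLinearLayout() behavior for the hypothetical memory-kind layout mem sketched earlier; LinearEncodingAttr::get(ctx, ...) is assumed to be the generated attribute getter:

auto enc = LinearEncodingAttr::get(ctx, mem); // assumed generated getter
// The earliest offset bases of mem step along dim1, so dim1 is the
// fastest-varying dimension and the order comes out as [1, 0].
SmallVector<unsigned> order = enc.getOrder();
// For the memory kind, toLinearLayout() returns the stored layout as-is,
// so the requested shape must already match (here mem covers 4x4).
LinearLayout ll = enc.toLinearLayout({4, 4});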
4 changes: 3 additions & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -879,7 +879,9 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   // Layouts are distributed or shared in triton core
   // To add a new layout add an else-if clause
   LinearLayout result = LinearLayout::empty();
-  if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
+  if (auto linearLayout = dyn_cast<LinearEncodingAttr>(layout)) {
+    result = linearLayout.toLinearLayout(shape);
+  } else if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
     result = distributed.toLinearLayout(shape);
   } else {
     auto shared = dyn_cast<SharedEncodingAttr>(layout);
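The position of the new clause matters: LinearEncodingAttr presumably also implements DistributedEncodingTrait (it provides getOrder, getSizePerThread, and friends), so checking it first keeps memory-kind linear layouts off the generic distributed path. A standalone sketch of the resulting dispatch (the function name resolve is hypothetical):

LinearLayout resolve(ArrayRef<int64_t> shape, Attribute layout) {
  // Linear encodings handle both the memory and the distributed kind.
  if (auto linear = dyn_cast<LinearEncodingAttr>(layout))
    return linear.toLinearLayout(shape);
  // All other distributed encodings (blocked, mfma, sliced, ...).
  if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout))
    return distributed.toLinearLayout(shape);
  llvm_unreachable("unsupported encoding");
}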
22 changes: 22 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -294,3 +294,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]], block = []}>
+#transposed_in_regs = #ttg.linear<{register = [[1, 0], [2, 0], [0, 1], [0, 2], [0, 4]], lane = [[0, 8], [0, 16], [0, 32], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @custom_shared_layout(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked1>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+
+    %opA = ttg.convert_layout %arg0 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+
+    %transposed_in_regs = ttg.convert_layout %arg1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #transposed_in_regs>
+    %allocated = ttg.local_alloc %transposed_in_regs : (tensor<128x64xf16, #transposed_in_regs>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>
+    %opB = ttg.local_load %allocated : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+
+    %result = tt.dot %opA, %opB, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}
6 changes: 5 additions & 1 deletion
@@ -210,7 +210,11 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                  (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

-  auto sharedLayout = cast<SharedEncodingAttr>(aTensorTy.getEncoding());
+  auto sharedLayout = dyn_cast<SharedEncodingAttr>(aTensorTy.getEncoding());
+  if (!sharedLayout) {
+    // This converter supports only the shared encoding attribute
+    return Value();
+  }
   auto order = sharedLayout.getOrder();
   assert((rank == 2 || order[2] == 0) &&
          "expect batch to be the slowest dimension");