introduce separate layout for linearShared
binarman committed Jan 31, 2025
1 parent 1a3ecd0 commit c24e6bf
Showing 6 changed files with 298 additions and 71 deletions.
28 changes: 28 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -666,6 +666,34 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Shared Linear Encoding
//===----------------------------------------------------------------------===//

def SharedLinearEncodingAttr : SharedEncoding<"SharedLinearEncoding", "shared_linear_encoding"> {
let mnemonic = "shared_linear";

let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
}];

let parameters = (
ins
LinearLayoutParam:$linearLayout,
"int32_t":$alignment__
);

let extraClassDeclaration = extraSharedEncodingDeclaration;

let genVerifyDecl = 1;
  // Example of assembly format:
  // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]],
  //   block = [],
  //   alignment = 16}>
let hasCustomAssemblyFormat = 1;
}
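
For illustration, a minimal sketch of how the new attribute could be constructed programmatically; the helper name, include paths, and the chosen bases/alignment are assumptions, and the `get` builder signature is inferred from the parameter list above.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

// Hypothetical example: a shared-memory layout for a 16x16 tile, expressed
// as linear bases over the "offset" and "block" input dimensions.
mlir::Attribute buildSharedLinearExample(mlir::MLIRContext *ctx) {
  using mlir::StringAttr;
  using mlir::triton::LinearLayout;
  auto S = [&](llvm::StringRef s) { return StringAttr::get(ctx, s); };

  LinearLayout::BasesT bases;
  // Offset bits 0..3 walk dim1 (1, 2, 4, 8); bits 4..7 walk dim0.
  bases[S("offset")] = {{0, 1}, {0, 2}, {0, 4}, {0, 8},
                        {1, 0}, {2, 0}, {4, 0}, {8, 0}};
  bases[S("block")] = {}; // single CTA, no block dimension
  LinearLayout ll(std::move(bases), {S("dim0"), S("dim1")});

  // Alignment is arbitrary here; the builder signature is assumed to be
  // (ctx, linearLayout, alignment) per the TableGen parameters above.
  return mlir::triton::gpu::SharedLinearEncodingAttr::get(ctx, std::move(ll),
                                                          /*alignment=*/16);
}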


//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -244,7 +244,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
// We propose case 2 (see comments below), which provides a more general
// solution for all swizzled shared memory scenarios, including the edge case
// mentioned above.
-  if (sharedEnc &&
+  if (mlir::isa<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding()) &&
+      sharedEnc &&
isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
{{kRegister, regId},
288 changes: 245 additions & 43 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1189,18 +1189,6 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
// Linear Encoding
//===----------------------------------------------------------------------===//

- namespace {
-
- bool isMemoryEncoding(const mlir::triton::LinearLayout &ll) {
-   return (ll.getInDimNames().begin()->str() == "offset");
- }
-
- bool isDistributedEncoding(const mlir::triton::LinearLayout &ll) {
-   return (ll.getInDimNames().begin()->str() == "register");
- }
-
- } // namespace

// FIXME Can we take the LinearLayout by const&?
LogicalResult
LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -1210,20 +1198,12 @@ LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
// warp = [[16, 0], [32, 0]],
// block = []}>
- // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, 0]],
- //   block = []}>
- // The input dims must be {register, lane, warp, block} or {offset, block}
+ // The input dims must be {register, lane, warp, block}
// The output dims of the linear layout should be dim0..dim[rank-1]

- SmallVector<std::string> expectedInDims;
+ SmallVector<std::string> expectedInDims = {"register", "lane", "warp",
+                                            "block"};
- // check if this is a memory of distributed layout
- if (isMemoryEncoding(linearLayout)) {
-   expectedInDims = {"offset", "block"};
- } else {
-   assert(isDistributedEncoding(linearLayout));
-   expectedInDims = {"register", "lane", "warp", "block"};
- }

for (const auto &[i, dims] : llvm::enumerate(
llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
@@ -1292,20 +1272,12 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) {
LinearLayout::BasesT bases;

// Parse the basis names in order (the order is relevant)
std::vector<std::string> inDimNames = {"offset", "register", "lane", "warp",
"block"};
std::vector<std::string> inDimNames = {"register", "lane", "warp", "block"};

for (const auto &inDimNameStr : inDimNames) {
auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
Attribute value = dict.get(inDimName);

-   // there are two cases:
-   // memory layout: {offset, block}
-   // and distributed layout: {register, lane, warp, block}
-   // this check skips redundant fields
-   if (!value)
-     continue;

// Expecting an array of arrays
auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
if (!arrayOfArraysAttr) {
@@ -1513,24 +1485,14 @@ SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
// This order is as good as any really
std::iota(order.rbegin(), order.rend(), 0);

- if (isMemoryEncoding(linearLayout)) {
-   return orderPerDim(StringAttr::get(getContext(), "offset"), order);
- } else {
-   assert(isDistributedEncoding(linearLayout));
-   return orderPerDim(StringAttr::get(getContext(), "register"), order);
- }
+ return orderPerDim(StringAttr::get(getContext(), "register"), order);
}

LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto ll = getLinearLayout();
auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
llvm::SmallVector<StringAttr> permutedDims;
- if (isMemoryEncoding(ll)) {
-   // memory encoding should be 1-to-1 compatible with tensor size
-   // TODO: check layout size equal to shape
-   return ll;
- }
for (auto dim : getRepOrder()) {
permutedDims.push_back(canonicalDims[dim]);
namedShape[canonicalDims[dim]] = shape[dim];
@@ -1582,6 +1544,246 @@ unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
return product(getElemsPerThread(shape, eltTy));
}

//===----------------------------------------------------------------------===//
// Shared Linear Encoding
//===----------------------------------------------------------------------===//

LogicalResult
SharedLinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LinearLayout linearLayout, int32_t alignment) {
// Example of SharedLinearEncodingAttr
// <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]], block = []}>
// The input dims must be {offset, block}
// The output dims of the linear layout should be dim0..dim[rank-1]

SmallVector<std::string> expectedInDims = {"offset", "block"};

for (const auto &[i, dims] : llvm::enumerate(
llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
const auto &[dim, expectedDimStr] = dims;
if (dim.str() != expectedDimStr) {
return emitError() << "Expected input dimension " << i << " to be '"
<< expectedDimStr << "'. Got " << dim;
}
}

// outDims are ['dim0', 'dim1', ...]
for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) {
if (dim.str() != ("dim" + llvm::Twine(i)).str()) {
return emitError()
<< "Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
<< dim << " at position " << i;
}
}

const auto &bases = linearLayout.getBases();
auto nonZero = [](auto val) { return val != 0; };
for (const auto &dimBases : llvm::make_second_range(bases)) {
if (!llvm::all_of(dimBases, [&](const auto &basis) {
return std::count_if(basis.begin(), basis.end(), nonZero) <= 1;
})) {
return emitError()
<< "In a distributed layout, each base must move in at most one "
"dimension.";
}
}

return success();
}
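
As a standalone illustration (plain C++, hypothetical helper name) of the per-basis rule the verifier enforces, each basis vector may be non-zero in at most one output dimension:

#include <algorithm>
#include <cstdint>
#include <vector>

// True if the basis moves along at most one output dimension.
bool movesInAtMostOneDim(const std::vector<int32_t> &basis) {
  return std::count_if(basis.begin(), basis.end(),
                       [](int32_t v) { return v != 0; }) <= 1;
}

// movesInAtMostOneDim({0, 4}) -> true  (moves only along dim1)
// movesInAtMostOneDim({8, 0}) -> true  (moves only along dim0)
// movesInAtMostOneDim({0, 0}) -> true  (broadcasting basis)
// movesInAtMostOneDim({1, 1}) -> false (a "diagonal" basis is rejected)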

void SharedLinearEncodingAttr::print(mlir::AsmPrinter &printer) const {
  // We don't use the default implementation as it's a bit too verbose.
  // This prints in a shape-agnostic format, in the sense that we don't
  // explicitly print the outShape of the LL.
  // We always assume LLs to be surjective.
  // <{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]],
  //   block = [],
  //   alignment = 16}>
  auto ll = getLinearLayout();
  printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) {
    return base.first.str() + " = " + "[" +
           join(base.second, ", ",
                [](const std::vector<int32_t> &vec) {
                  return "[" + join(vec, ", ") + "]";
                }) +
           "]";
  }) << ", alignment = " << getAlignment__() << "}>";
}

Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};

DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};

if (parser.parseGreater().failed())
return {};

LinearLayout::BasesT bases;

// Parse the basis names in order (the order is relevant)
std::vector<std::string> inDimNames = {"offset", "block"};

for (const auto &inDimNameStr : inDimNames) {
auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
    Attribute value = dict.get(inDimName);
    if (!value) {
      parser.emitError(parser.getCurrentLocation(),
                       "Expected bases for input dimension '")
          << inDimName.getValue() << "'";
      return {};
    }

// Expecting an array of arrays
auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
if (!arrayOfArraysAttr) {
parser.emitError(parser.getCurrentLocation(),
"Expected array of arrays for basis of '")
<< inDimName.getValue() << "'";
return {};
}

std::vector<std::vector<int32_t>> inDimBases;
for (Attribute arrayAttr : arrayOfArraysAttr) {
auto intArrayAttr = mlir::dyn_cast<ArrayAttr>(arrayAttr);
if (!intArrayAttr) {
parser.emitError(parser.getCurrentLocation(),
"Expected array of integers in basis for '")
<< inDimName.getValue() << "'";
return {};
}
std::vector<int32_t> basis;
for (Attribute intAttr : intArrayAttr) {
auto intValueAttr = mlir::dyn_cast<IntegerAttr>(intAttr);
if (!intValueAttr) {
parser.emitError(parser.getCurrentLocation(),
"Expected integer in basis for '")
<< inDimName.getValue() << "'";
return {};
}
basis.push_back(intValueAttr.getInt());
}
inDimBases.push_back(std::move(basis));
}
bases[inDimName] = std::move(inDimBases);
}
size_t rank = 0;
for (const auto &basesDim : llvm::make_second_range(bases)) {
if (!basesDim.empty()) {
rank = basesDim[0].size();
break;
}
}

  // A rank of 0 means every basis list was empty; supporting that would
  // require serialising the rank as well. We can do this if we ever need it.
if (rank == 0) {
parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported");
return {};
}

// Generate standard outDimNames (dim0, dim1, ...)
SmallVector<StringAttr> outDimNames;
for (int i = 0; i < rank; ++i) {
outDimNames.push_back(
StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i)));
}

// Create LinearLayout
LinearLayout linearLayout(std::move(bases), std::move(outDimNames));

// Parse alignment attribute
auto alignmentStr = StringAttr::get(parser.getContext(), "alignment");
unsigned alignment = 0;
  auto alignmentAttr = dict.getNamed(alignmentStr);
  if (!alignmentAttr.has_value()) {
    parser.emitError(parser.getCurrentLocation(),
                     "Expected mandatory 'alignment' attribute");
    return {};
  }

if (parseUInt(parser, alignmentAttr.value(), alignment, "alignment").failed())
return {};

  // Create and return the SharedLinearEncodingAttr
return parser.getChecked<SharedLinearEncodingAttr>(
parser.getContext(), std::move(linearLayout), alignment);
}
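
A round-trip sketch of what the parser above is meant to accept, driven through mlir::parseAttribute. The `#ttg.` prefix, the dialect-loading line, and the concrete bases/alignment are assumptions (the dialect alias may differ between Triton versions).

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

mlir::Attribute parseSharedLinearExample(mlir::MLIRContext &ctx) {
  // The TritonGPU dialect must be loaded so the attribute can be resolved.
  ctx.getOrLoadDialect<mlir::triton::gpu::TritonGPUDialect>();
  // Dictionary with "offset"/"block" bases plus the mandatory "alignment",
  // matching the keys the custom parser looks up.
  return mlir::parseAttribute(
      "#ttg.shared_linear<{offset = [[0, 1], [0, 2], [1, 0], [2, 0]], "
      "block = [], alignment = 16}>",
      &ctx);
}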

unsigned SharedLinearEncodingAttr::getVec() const {
  llvm_unreachable("getVec is not supported by SharedLinearEncodingAttr");
  return 0;
};
unsigned SharedLinearEncodingAttr::getPerPhase() const {
  llvm_unreachable("getPerPhase is not supported by SharedLinearEncodingAttr");
  return 0;
};
unsigned SharedLinearEncodingAttr::getMaxPhase() const {
  llvm_unreachable("getMaxPhase is not supported by SharedLinearEncodingAttr");
  return 0;
};
SmallVector<unsigned> SharedLinearEncodingAttr::getOrder() const {
auto dimName = StringAttr::get(getContext(), "offset");
auto ll = getLinearLayout();
const auto &bases = ll.getBases().find(dimName)->second;

unsigned rank = bases.begin()->size();
auto defaultOrder = llvm::iota_range(0u, rank, false);
llvm::SetVector<unsigned> order;
auto nonZero = [](auto val) { return val != 0; };
for (const auto &basis : bases) {
// Bases can have one or zero non-zero elements
// Skip a basis if it's broadcasting (all zeros)
// e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
if (it != basis.end()) {
auto i = it - basis.begin();
order.insert(i);
}
}
// If any dim is missing, we add them in the defaultOrder
for (auto i : defaultOrder) {
order.insert(i);
}
return SmallVector<unsigned>(order.begin(), order.end());
};
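
To make the order derivation concrete, here is a small standalone sketch (plain C++, hypothetical function name) that mirrors the loop above: record the first non-zero dimension of each offset basis, then append any dimensions that never appear.

#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors SharedLinearEncodingAttr::getOrder on plain vectors.
std::vector<unsigned> orderFromBases(
    const std::vector<std::vector<int32_t>> &offsetBases, unsigned rank) {
  std::vector<unsigned> order;
  auto seen = [&](unsigned d) {
    return std::find(order.begin(), order.end(), d) != order.end();
  };
  for (const auto &basis : offsetBases) {
    auto it = std::find_if(basis.begin(), basis.end(),
                           [](int32_t v) { return v != 0; });
    if (it != basis.end()) {
      unsigned d = it - basis.begin();
      if (!seen(d))
        order.push_back(d);
    }
  }
  for (unsigned d = 0; d < rank; ++d) // dims that only broadcast come last
    if (!seen(d))
      order.push_back(d);
  return order;
}

// Example: bases {{0, 1}, {0, 2}, {1, 0}} with rank 2 give order {1, 0},
// i.e. dim1 is the fastest-varying dimension in shared memory.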
SmallVector<unsigned> SharedLinearEncodingAttr::getCTAsPerCGA() const {
  // TODO: replace with block analysis?
auto rank = getLinearLayout()
.getBasis(StringAttr::get(getContext(), "offset"), 0)
.size();
return SmallVector<unsigned>(rank, 1);
};
SmallVector<unsigned> SharedLinearEncodingAttr::getCTASplitNum() const {
  // TODO: replace with block analysis?
auto rank = getLinearLayout()
.getBasis(StringAttr::get(getContext(), "offset"), 0)
.size();
return SmallVector<unsigned>(rank, 1);
};
SmallVector<unsigned> SharedLinearEncodingAttr::getCTAOrder() const {
  // TODO: replace with block analysis?
auto rank = getLinearLayout()
.getBasis(StringAttr::get(getContext(), "offset"), 0)
.size();
return llvm::to_vector(llvm::iota_range<unsigned>(0, rank, false));
};
bool SharedLinearEncodingAttr::hasLeadingOffset() const {
  return false;
};
int32_t SharedLinearEncodingAttr::getAlignment() const {
return this->getAlignment__();
};
SmallVector<unsigned>
SharedLinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for shared layout");
return {};
}
unsigned
SharedLinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for shared layout");
return 0;
}

//===----------------------------------------------------------------------===//
// MMA encoding
//===----------------------------------------------------------------------===//
