Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support shared encoding defined with linear layout #5720

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

binarman
Copy link
Contributor

@binarman binarman commented Jan 27, 2025

This PR enables basic support for linear layout used as a shared encoding.

This change is needed to support custom memory layout introduced in #4984

@binarman
Copy link
Contributor Author

+cc @lezcano

AMD team wants to experiment with non standard swizzling patterns(for example #4984), so I had an idea to use LinearEncodingAttr as a memory layout. While implementation I've noticed that it inherits DistributedEncoding.

What do you think about renaming LinearEncodingAttr to DistributedLinearEncodingAttr and introduction of new MemoryLinearEncodingAttr encoding?

@masahi
Copy link
Collaborator

masahi commented Jan 27, 2025

I'm interested in this discussion as well. The NV lowering path heavily relies on the legacy encoding of the SMEM layout, so going all in on using LL for SMEM at the IR level is difficult. Moreover, how are you going to query swizzling properties of SMEM represented only via LL?

@Jokeren
Copy link
Contributor

Jokeren commented Jan 27, 2025

I'm interested in this discussion as well. The NV lowering path heavily relies on the legacy encoding of the SMEM layout, so going all in on using LL for SMEM at the IR level is difficult. Moreover, how are you going to query swizzling properties of SMEM represented only via LL?

We can query the first few bases of the offset dimension to see if they are contiguous and check if the rest bases do not overlap with them.

There's indeed an algorithm we plan to implement for the ldmatrix path that won't rely on the standard shared encoding

@Jokeren
Copy link
Contributor

Jokeren commented Jan 27, 2025

AMD path might be much actually simpler since it doesn't slice shared memory and doesn't have ldmatrix/stmatrix instructions.

The only concern is that checking if base names have "offset" to determine a shared encoding isn't a solid solution.

@lezcano
Copy link
Contributor

lezcano commented Jan 28, 2025

@masahi just bumped into this quite recently as well, and as @Jokeren has mentioned, we'll start using new shmem layouts sooner than later.

The issue with SharedLayouts is that they don't have an API like the one DistributedLayouts have. They just have very few attributes that are rather unique to their own structure. In general, a characterisation of the shared memory layout that we care about may be given by is a LinearLayout for which al its bases have at most 2 bits equal to one (popc(b) <= 2). Also, the only case in which we have broadcasting (a basis that is a zero) is the mxfp4 in BW case (we don't model it with a basis equal to zero yet, but we'll do so in the future). In all the other cases these LLs are bijective maps onto their domain.

Taking this structure into account, the tricky part here is to create an API that's returns the relevant properties we need to work with this layout. This would be the equivalent to

SmallVector<unsigned> basesPerDimImpl(const LinearLayout::BasesT &namedBases,
StringAttr dimName, size_t rank,
bool skipBroadcast = true) {
const auto &bases = namedBases.find(dimName)->second;
if (bases.empty()) {
return SmallVector<unsigned>(rank, 1);
}
SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = 0;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
// Skip a basis if it's broadcasting (all zeros)
// e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
if (it != basis.end()) {
nonZeroIdx = it - basis.begin();
ret[nonZeroIdx] *= 2;
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
ret[nonZeroIdx] *= 2;
}
}
return ret;
}
SmallVector<unsigned>
LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const {
auto ll = getLinearLayout();
auto rank = ll.getNumOutDims();
return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast);
}
SmallVector<unsigned>
LinearEncodingAttr::orderPerDim(StringAttr dimName,
ArrayRef<unsigned> defaultOrder) const {
auto ll = getLinearLayout();
const auto &bases = ll.getBases().find(dimName)->second;
llvm::SetVector<unsigned> order;
auto nonZero = [](auto val) { return val != 0; };
for (const auto &basis : bases) {
// Bases can have one or zero non-zero elements
// Skip a basis if it's broadcasting (all zeros)
// e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
if (it != basis.end()) {
auto i = it - basis.begin();
order.insert(i);
}
}
// If any dim is missing, we add them in the defaultOrder
for (auto i : defaultOrder) {
order.insert(i);
}
return SmallVector<unsigned>(order.begin(), order.end());
}
// [Note. Divergence of methods wrt. legacy layouts]
// For smaller shapes where the CTATile is larger than the output
// tensor, some methods return different values than the legacy layouts. I think
// this is benign tho. An example: what is the the vector of `warpsPerCTA` if
// all the warps hold the same data? I think it should be [1, 1], even if we
// have 4 warps. But perhaps for this we have to add some masking in some
// places... We'll see
SmallVector<unsigned> LinearEncodingAttr::getRepOrder() const {
// This is not correct, but:
// - It happens to agree in most places with the legacy layout
// - getRepOrder does not make sense for LinearEncodingAttr as it already has
// the same shape as the tensor that uses it
return getOrder();
}
SmallVector<unsigned> LinearEncodingAttr::getCTAsPerCGA() const {
// CTAs are split into an identity part (SplitNum) and a broadcast part
return basesPerDim(StringAttr::get(getContext(), "block"),
/*skipBroadcast=*/false);
}
SmallVector<unsigned> LinearEncodingAttr::getCTAOrder() const {
return orderPerDim(StringAttr::get(getContext(), "block"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getCTASplitNum() const {
return basesPerDim(StringAttr::get(getContext(), "block"));
}
SmallVector<unsigned> LinearEncodingAttr::getWarpsPerCTA() const {
return basesPerDim(StringAttr::get(getContext(), "warp"));
}
SmallVector<unsigned> LinearEncodingAttr::getWarpOrder() const {
return orderPerDim(StringAttr::get(getContext(), "warp"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getThreadsPerWarp() const {
return basesPerDim(StringAttr::get(getContext(), "lane"));
}
SmallVector<unsigned> LinearEncodingAttr::getThreadOrder() const {
return orderPerDim(StringAttr::get(getContext(), "lane"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
auto rank = getRepOrder().size();
auto ll = getLinearLayout();
auto ctx = getContext();
auto kRegister = StringAttr::get(ctx, "register");
// We canonicalize on the spot, as if we use CGAs the regs are not in
// canonical form The order is [reg, lane, warp, rep, block], so we first
// remove the blocks
llvm::SmallVector<unsigned> ctaShape;
for (auto [shape, cgaNum] :
llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) {
ctaShape.push_back(shape / cgaNum);
}
LinearLayout::BasesT bases = ll.getBases();
llvm::SetVector<unsigned> reverseRepOrder;
auto nonZero = [](auto val) { return val != 0; };
auto &registers = bases[StringAttr::get(ctx, "register")];
while (!registers.empty()) {
auto &basis = registers.back();
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// If there's broadcasting (base == zeros) there are no more reps
if (it == basis.end()) {
break;
}
auto dim = it - basis.begin();
reverseRepOrder.insert(dim);
// As soon as we stop finding reps, we stop
if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) {
break;
}
ctaShape[dim] /= 2;
registers.pop_back();
}
return basesPerDimImpl(bases, kRegister, rank);
}
SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
auto rank = getLinearLayout().getNumOutDims();
SmallVector<unsigned> order(rank);
// Choose [rank-1, rank-2, ... 0] as the default order in case
// there are dims that do not move in the register
// This order is as good as any really
std::iota(order.rbegin(), order.rend(), 0);
return orderPerDim(StringAttr::get(getContext(), "register"), order);
}
LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto ll = getLinearLayout();
auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
llvm::SmallVector<StringAttr> permutedDims;
for (auto dim : getRepOrder()) {
permutedDims.push_back(canonicalDims[dim]);
namedShape[canonicalDims[dim]] = shape[dim];
}
ll = ll.transposeOuts(permutedDims);
ll = ensureLayoutNotSmallerThan(ll, namedShape);
ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false);
ll = ll.transposeOuts(canonicalDims);
return ll;
}
SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// When broadcasting the layout the shape changes, otherwise the shape is
// the same as the shape of the tensor
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor
// We choose the former for BC
auto scaledLayout = get(getContext(), toLinearLayout(shape));
auto kRegister = StringAttr::get(getContext(), "register");
return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false);
}

This API is going to be very different to the one from LinearEncodingAttr, so they should be separate, and I expect it to be also a bit different to the one in SharedEncoding. As such, it would be best to create a parent class (LinearSharedEncodingAttr?) that just holds a LinearLayout and make SharedEncoding inherit from it. From here, the way forward would not be simply to implement getVec, getPerPhase, getMaxPhase etc, but actually change passes, aux functions, etc to support LinearSharedEncodingAttr in its full generality (i.e., make them to work with an arbitrary invertible (probably easiest not to think about the mxfp4 case for now) linear layout with popc(b) in (1, 2). You'll have to come up with your own API for this, which may be tricky.

This is not the easiest task, but it's what we'll need to tackle if we want to support generic swizzled layouts. If this is a bit too much in one go, that's fine. Another way to go is to define another subclass of SharedEncodings that are less general than the one I described above, and start implementing that one and adding support across the codebase for it. That is what @masahi is set to do with mxfp4 for BW.

I'll probably make some strides towards tackling the general case next month anyway.

@binarman
Copy link
Contributor Author

It seems currently I do not fully understand all issues related to general SMEM linear layout encoding.
In my mind shared layout should have only one method: convert to linear layout, which should be trivial in case of linear layout SMEM encoding.

Since we have #5764, which I suppose will be merged soon, I will try go in two directions simultaneously:

@lezcano
Copy link
Contributor

lezcano commented Jan 31, 2025

That PR is being reworked following Thomas' advice. The final state will be rather similar to what you are saying: a common class with no methods. In particular, you can implement toLinearLayout, and generally, you should support any methods you are interested on supporting by inspecting the corresponding linear layout (and nothing else). The issue is that a generic linear layout is a rather complex object, so that last point is easier said than done (this is why we need a characterisation of the structure of the layouts we're interested in), but yeah.

This PR enables basic support for linear layout used as a shared encoding.
@masahi
Copy link
Collaborator

masahi commented Jan 31, 2025

That PR is being reworked following Thomas' advice

@binarman Yes, the rework is going to take some time. As Thomas suggested, I'll first add SharedEncodingTrait and SwizzleSharedEncodingAttr. You can assume that SharedEncodingTrait is mostly an empty interface, from which maybe you can derive an another interface for SMEM encoding backed by LinearLayout. To avoid putting myself in your critical path, I should probably make a small PR to introduce just bare SharedEncodingTrait.

@binarman binarman force-pushed the shared_linear_layout branch from a91f8e9 to c24e6bf Compare January 31, 2025 22:11
@ThomasRaoux
Copy link
Collaborator

That PR is being reworked following Thomas' advice

@binarman Yes, the rework is going to take some time. As Thomas suggested, I'll first add SharedEncodingTrait and SwizzleSharedEncodingAttr. You can assume that SharedEncodingTrait is mostly an empty interface, from which maybe you can derive an another interface for SMEM encoding backed by LinearLayout. To avoid putting myself in your critical path, I should probably make a small PR to introduce just bare SharedEncodingTrait.

Actually let me take this piece, I'll try to send something to unblock both of you soon

@ThomasRaoux
Copy link
Collaborator

ok I ended up doing a bit more refactoring, here is the PR:
#5786

That should allow extending the new nivida mma shared layout as well as exposing linear layout the same way we do for distributed layout

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants