Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops #5673

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
@@ -830,7 +830,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
Elementwise,
SameOperandsAndResultEncoding,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>
]> {
let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
let description = [{
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {

def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
InferTypeOpWithLayoutEquivalence,
SameOperandsAndResultElementType]> {
let summary = "transpose the descriptor";

6 changes: 6 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -1037,6 +1037,12 @@ void ElementwiseInlineAsmOp::getEffects(
SideEffects::DefaultResource::get());
}

Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
if (getPure())
return Speculation::Speculatable;
return Speculation::NotSpeculatable;
}

LogicalResult ElementwiseInlineAsmOp::verify() {
if (getNumOperands() >= 1) {
auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());
28 changes: 15 additions & 13 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -463,15 +463,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
return {};
}

LogicalResult MemDescTransOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
LogicalResult
MemDescTransOp::inferReturnTypes(MLIRContext *context,
std::optional<Location> location,
MemDescTransOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {

// type is the same as the input
auto argTy = cast<MemDescType>(operands[0].getType());
auto argShape = argTy.getShape();
auto order = properties.as<Properties *>()->order.asArrayRef();
SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
auto argTy = cast<MemDescType>(adaptor.getSrc().getType());
auto shape = argTy.getShape();
auto order = adaptor.getOrder();
SmallVector<int64_t> retShape = applyPermutation(shape, order);

auto retEltTy = argTy.getElementType();
Attribute argEncoding = argTy.getEncoding();
@@ -480,17 +482,17 @@ LogicalResult MemDescTransOp::inferReturnTypes(
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
.failed()) {
return failure();
}
}
auto memDescTy = cast<MemDescType>(argTy);
inferredReturnTypes.push_back(MemDescType::get(
retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
memDescTy.getMutableMemory()));
inferredReturnTypes.push_back(
MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
argTy.getMutableMemory()));
return success();
}

// LocalAllocOp
void LocalAllocOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
142 changes: 0 additions & 142 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
@@ -24,36 +24,6 @@ namespace {
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
!isa<arith::ArithDialect>(op->getDialect()))
return false;

// Quick handling to fix loading issues when computing the original
// bitwidth is unable to realize that there is a mixed-precision dot
// (hence kWidth = 1) but wants to hoist through the type conversion.
if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
return false;

// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
// supported.
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
return false;

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(op)) {
Type opType = getElementTypeOrSelf(op->getOperand(0));
if (opType.isInteger(1))
return false;
}

return true;
}

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation *op) {
@@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
}
};

// Move convert-to-dot-operand "up" past elementwise ops:
//
// convert(elementwise(x)) #dot_operand ->
// elementwise(convert(x, #dot_operand)).
//
// The goal is to put the convert right next to the originating load. If we can
// accomplish this, then we can save a shmem round-trip:
//
// Before:
//
// - Load from global into shmem using an async copy.
// - Load from shmem into a #blocked layout.
// - Do elementwise ops over #blocked layout.
// - Convert to #dot_operand (round-trip through shmem).
// - Do dot.
//
// After:
//
// - Load from global into shmem using an async copy (same as before).
// - Load from shmem into a #dot_operand layout.
// - Do elementwise ops over #dot_operand layout.
// - Do dot.
//
// This can also be propagated when we have a constant, instead of a load.
//
// Eliminating the shmem round-trip is such a big win, we're willing to do it
// even if this duplicates work because some of the elementwise ops have uses
// that don't flow into the dot. On the other hand, we only want to do this if
// we can in fact reduce shmem round-trips: For example, simply moving a convert
// up above e.g. an `add` now means we have *two* converts. That's worse,
// unless we can continue moving the converts upwards and eventually merge them.
// So we try to check that this will be beneficial before making any changes.
class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
return failure();

auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
if (!srcTy)
return failure();

if (!all_of(src->getOperandTypes(),
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

if (!canHoistDotOpEncV2(src, dotOpEnc))
return failure();

// Check that the conversion is transitively dependent on a load or a
// constant, and all operations between it and the convert are layout
// preserving.
//
// TODO(jlebar): This is accidentally quadratic; we iterate over the whole
// slice but then at the end we only modify one op!
SetVector<Operation *> slice;
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(cvt.getOperation(), &slice, opt);

// TODO(jlebar): This is too conservative when there are multiple loads in
// the chain. If one of the loads has a non-layout-preserving op and the
// other does not, then we may or may not accept the chain, depending on
// which load gets hit first by getBackwardSlice. For example:
// cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
// cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
bool foundInitializer = false;
// Reverse the slice so that we start directly above the convert and check
// that every op allows hoisting until we find a load or a constant.
for (Operation *currOp : llvm::reverse(slice)) {
if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
foundInitializer = true;
break;
}
if (!canHoistDotOpEncV2(currOp, dotOpEnc))
return failure();
}
if (!foundInitializer)
return failure();

SmallVector<ConvertLayoutOp> newOperands;
for (auto operand : src->getOperands()) {
// We checked earlier that all operands are ranked tensors.
auto operandTy = cast<RankedTensorType>(operand.getType());
Type newCvtTy = RankedTensorType::get(
srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
newOperands.push_back(
rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
}
auto newRet = rewriter.clone(*src);
for (int i = 0; i < newOperands.size(); i++)
newRet->setOperand(i, newOperands[i]);
newRet->getResult(0).setType(RankedTensorType::get(
srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));

rewriter.replaceOp(cvt, newRet->getResults());
return success();
}
};

// Rewrite
//
// dot(alloc(trans() #shared1) ->
@@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass
mlir::RewritePatternSet patterns(context);
patterns.add<MMAV3HoistLayoutConversion>(context);
patterns.add<SwizzleShmemConvert>(context);
if (this->hoistLayoutConversion.getValue())
patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransMMAV3Plus>(context);
patterns.add<MMAV3UseRegOperand>(context);
patterns.add<InjectTMemCopy>(context);
Loading