Commit 949f383

Pass to block dynamic dimensions of operands of `iree_linalg_ext.attention`.

The use of `IntegerRangeAnalysis` and `IntegerDivisibilityAnalysis`
gives range and divisibility information for constants passed to the
dispatch. This can be used to infer range and divisibility information
for the dynamic dimensions of all tensor values in the dispatch. This
PR adds an analysis (`TensorDynamicDimAnalysis`) to do this.

This analysis is then used to expand the dimensions of attention-operation
operands that are dynamic but known to be divisible by a compile-time
static value. This gets the operations into a form that the AMDGPU
backend can compile to target the mfma intrinsics.
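
As a minimal sketch (shapes and value names are illustrative, not from this
PR), blocking a query operand whose second dimension the analysis proves
divisible by 16 materializes the following IR around the operand:

```mlir
// %q : tensor<4x?x32xf16>, with dim 1 a known multiple of 16.
%c1 = arith.constant 1 : index
%dim = tensor.dim %q, %c1 : tensor<4x?x32xf16>
%outer = affine.apply affine_map<()[s0] -> (s0 floordiv 16)>()[%dim]
%expanded = tensor.expand_shape %q [[0], [1, 2], [3]]
    output_shape [4, %outer, 16, 32]
    : tensor<4x?x32xf16> into tensor<4x?x16x32xf16>
%barrier = util.optimization.barrier %expanded : tensor<4x?x16x32xf16>
%collapsed = tensor.collapse_shape %barrier [[0], [1, 2], [3]]
    : tensor<4x?x16x32xf16> into tensor<4x?x32xf16>
```

Reshape propagation then pushes the `tensor.collapse_shape` down past the
attention op, so the op itself operates on the type with the static inner
dimension of 16.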

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar committed Oct 23, 2024
1 parent 2d64ab1 commit 949f383
Showing 9 changed files with 740 additions and 1 deletion.
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -86,6 +86,7 @@ iree_compiler_cc_library(
name = "Common",
srcs = [
"AddFastMathFlags.cpp",
"BlockDynamicDimensions.cpp",
"BubbleUpOrdinalOps.cpp",
"BufferizationAnalysis.cpp",
"BufferizeCopyOnlyDispatchesPass.cpp",
@@ -137,6 +138,7 @@ iree_compiler_cc_library(
"RemoveSingleIterationLoop.cpp",
"ReplaceSlowMinMaxOps.cpp",
"SplitFullPartialTransferPass.cpp",
"TensorDynamicDimAnalysis.cpp",
"TensorToVectorVectorizePad.cpp",
"TestExecutablePreprocessing.cpp",
"TestPartitionableLoopsInterface.cpp",
@@ -155,6 +157,7 @@ iree_compiler_cc_library(
"ExtractAddressComputation.h",
"PassUtils.h",
"Passes.h",
"TensorDynamicDimAnalysis.h",
"TileSizeSelection.h",
"Transforms.h",
"UserConfig.h",
302 changes: 302 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -0,0 +1,302 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

using TensorDivisibilityInfo =
llvm::SmallDenseMap<unsigned, IREE::Util::ConstantIntDivisibility>;

namespace {

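/// Pattern that removes a `util.optimization.barrier` operation by forwarding
/// its operands to its results. Applied at the end of the pass to clean up
/// the barriers inserted while blocking dynamic dimensions.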
struct RemoveOptimizationBarrier
: public OpRewritePattern<IREE::Util::OptimizationBarrierOp> {
using OpRewritePattern<IREE::Util::OptimizationBarrierOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOp(barrierOp, barrierOp.getOperands());
return success();
}
};

/// This pass materializes information about the dynamic dimensions of
/// `tensor` operands of an operation in the IR. If a dynamic dimension is
/// known to be a multiple of a compile-time constant value, the pass expands
/// the shape of the operand. For example, if a `tensor` operand is of shape
/// `tensor<...x?x...>` and that dimension is known to be a multiple of 16,
/// the operand is expanded to `tensor<...x?x16x...>`, where the new dynamic
/// dimension is 1/16-th the size of the original one. This is done in two
/// steps:
/// 1) Replace each operand with such a dynamic dimension with the result of
///    a `tensor.expand_shape`/`tensor.collapse_shape` pair that materializes
///    the new static dimension and immediately folds it away. A
///    `util.optimization.barrier` is inserted in between to keep the pair
///    from being folded.
/// 2) Apply patterns that propagate the `tensor.collapse_shape` down into the
///    operation. This reuses the (fairly complex) dimension-expansion logic
///    already implemented in the reshape-propagation patterns.
/// At the end of the pass the optimization barriers are removed, which folds
/// away any `tensor.expand_shape`/`tensor.collapse_shape` pairs that were not
/// propagated.
struct BlockDynamicDimensionsPass final
: impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
void runOnOperation() override;
};
} // namespace

/// Retrieve the divisibility information for dynamic dimensions of `v` if
/// known.
static TensorDivisibilityInfo
getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
Value v) {
TensorDivisibilityInfo divisibilityInfo;
auto tensorType = dyn_cast<RankedTensorType>(v.getType());
if (!tensorType) {
return divisibilityInfo;
}

for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
if (!tensorType.isDynamicDim(index))
continue;
std::optional<IREE::Util::ConstantIntDivisibility> dimDivisibility =
dynamicDimAnalysis.getDivisibilityInfo(v, index);
if (!dimDivisibility)
continue;
divisibilityInfo[index] = std::move(dimDivisibility.value());
}

return divisibilityInfo;
}

/// For a value `v`, if a dynamic dimension is known to be a multiple of a
/// compile-time static value, insert
///
/// ```mlir
/// %v_expand = tensor.expand_shape %v
/// %barrier = util.optimization.barrier %v_expand
/// %v_collapse = tensor.collapse_shape %barrier
/// ```
///
/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are
/// inverses of each other. The `util.optimization.barrier` prevents these from
/// being folded away during reshape propagation. Returns the result of the
/// generated `tensor.collapse_shape`.
static std::optional<Value>
blockDynamicDimensionsOfValue(RewriterBase &rewriter,
const TensorDivisibilityInfo &divisibilityInfo,
Value v) {
auto tensorType = dyn_cast<RankedTensorType>(v.getType());
if (!tensorType) {
return std::nullopt;
}

// Check whether the dynamic dimensions of `v` have known divisibility info.
SmallVector<OpFoldResult> outputShape;
SmallVector<ReassociationIndices> reassociation;
Location loc = v.getLoc();

for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
reassociation.emplace_back(ReassociationIndices{});

// Check if this needs division.
if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) {
reassociation.back().push_back(outputShape.size());
outputShape.push_back(rewriter.getIndexAttr(dim));
continue;
}

// Split the dynamic dimension based on the divisibility info.
IREE::Util::ConstantIntDivisibility currDivisibility =
divisibilityInfo.lookup(index);
uint64_t factor = currDivisibility.sdiv();
AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
AffineExpr divExpr = s0.floorDiv(factor);
Value sourceDim = rewriter.create<tensor::DimOp>(loc, v, index).getResult();
OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply(
rewriter, loc, divExpr, ArrayRef<OpFoldResult>{sourceDim});
OpFoldResult newStaticDim = rewriter.getIndexAttr(factor);

reassociation.back().push_back(outputShape.size());
reassociation.back().push_back(outputShape.size() + 1);

outputShape.push_back(newDynamicDim);
outputShape.push_back(newStaticDim);
}

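// Worked example (hypothetical shapes): for `v` of type `tensor<4x?x6xf32>`
// with dim 1 known to be divisible by 16, the loop above produces
//   reassociation = [[0], [1, 2], [3]]
//   outputShape   = {4, dim(v, 1) floordiv 16, 16, 6}
// so `outputType` below becomes `tensor<4x?x16x6xf32>`.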
auto staticOutputShape =
llvm::map_to_vector(outputShape, [](OpFoldResult ofr) {
if (auto staticShapeAttr = dyn_cast<Attribute>(ofr)) {
return cast<IntegerAttr>(staticShapeAttr).getInt();
}
return ShapedType::kDynamic;
});
auto outputType = RankedTensorType::get(
staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());

Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, v, reassociation, outputShape);
Value barrier =
rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
.getResult(0);
Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
loc, tensorType, barrier, reassociation);
return collapseShape;
}

/// For an operation, replace the operands at the indices specified in
/// `limitToOperandIndices` with the result of a
/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize
/// information about dynamic dimensions that are known to be a multiple of a
/// compile-time static value. For example, given
///
/// ```mlir
/// %1 = <some_op>(..., %0, ...) : ... , tensor<4x?x6xf32>
/// ```
///
/// If the dynamic dimension is known to be a multiple of 16, then generate
///
/// ```mlir
/// %expanded = tensor.expand_shape %0
/// %barrier = util.optimization.barrier %expanded
/// %collapsed = tensor.collapse_shape %barrier
/// %1 = <some_op>(..., %collapsed, ...) : ... , tensor<4x?x6xf32>
/// ```
static LogicalResult blockDynamicDimensions(
RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandIndices) {
OpBuilder::InsertionGuard g(rewriter);

bool addedReshape = false;
for (OpOperand &operand : operation->getOpOperands()) {
if (!limitToOperandIndices.contains(operand.getOperandNumber()))
continue;
if (operand.get().getDefiningOp<tensor::CollapseShapeOp>())
continue;
TensorDivisibilityInfo operandDivisibilityInfo =
getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
if (operandDivisibilityInfo.empty())
continue;
std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
rewriter, operandDivisibilityInfo, operand.get());
if (newOperand) {
addedReshape = true;
rewriter.modifyOpInPlace(operation,
[&]() { operand.set(newOperand.value()); });
}
}
return success(addedReshape);
}

/// Insert `tensor.expand_shape` operations to materialize, in the IR,
/// information about dynamic dimensions that are known to be a multiple of a
/// compile-time known value, for the operands of the
/// `iree_linalg_ext.attention` operation.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
const TensorDynamicDimAnalysis &dynamicDimAnalysis,
IREE::LinalgExt::AttentionOp attentionOp) {
// Only block the q and k values.
llvm::SmallDenseSet<int64_t> prunedOperandsList;
prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
prunedOperandsList);
}

void BlockDynamicDimensionsPass::runOnOperation() {
Operation *operation = getOperation();
MLIRContext *context = &getContext();
TensorDynamicDimAnalysis dynamicDimAnalysis(operation);
if (failed(dynamicDimAnalysis.run())) {
return signalPassFailure();
}

IRRewriter rewriter(context);
auto walkResult = operation->walk(
[&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
rewriter.setInsertionPoint(attentionOp);
return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
attentionOp);
});
if (walkResult.wasInterrupted()) {
return signalPassFailure();
}

LLVM_DEBUG({
llvm::dbgs() << "After blocking dimensions:\n";
operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n";
});

{
RewritePatternSet bubbleExpandShapePatterns(context);
// Add patterns to "push down" the `tensor.collapse_shape` operations (these
// are the dual of the patterns that "bubble up" `tensor.expand_shape`
// operations).
linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
controlFn);
IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
bubbleExpandShapePatterns, controlFn);
// Add patterns to fold the "bubbled-up" `tensor.expand_shape` and
// "pushed-down" `tensor.collapse_shape` operations into their interface
// bindings or `tensor.empty` operations.
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
// Add some additional patterns that can simplify the IR and remove dead
// operations.
memref::populateResolveRankedShapedTypeResultDimsPatterns(
bubbleExpandShapePatterns);
populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
operation, std::move(bubbleExpandShapePatterns)))) {
operation->emitOpError(
"failed in application of bubble up expand shape patterns");
return signalPassFailure();
}
}

LLVM_DEBUG({
llvm::dbgs() << "After reshape propagation:\n";
operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n";
});

// Delete the optimization barriers and run some further cleanup.
{
RewritePatternSet removeBarrierOpsPatterns(context);
removeBarrierOpsPatterns.insert<RemoveOptimizationBarrier>(context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
removeBarrierOpsPatterns, context);
if (failed(applyPatternsAndFoldGreedily(
operation, std::move(removeBarrierOpsPatterns)))) {
operation->emitOpError("failed in cleanup patterns");
return signalPassFailure();
}
}

return;
}

} // namespace mlir::iree_compiler
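
After reshape propagation and barrier removal, the net effect is that the
consumer operates directly on the expanded type (a sketch with a placeholder
op, assuming dim 1 of `%0` is a known multiple of 16; the pass itself applies
this only to `iree_linalg_ext.attention`):

```mlir
// Before: %0 : tensor<4x?x6xf32>.
%r = "some_op"(%0) : (tensor<4x?x6xf32>) -> tensor<4x?x6xf32>

// After (sketch); %outer is dim(%0, 1) floordiv 16:
%expanded = tensor.expand_shape %0 [[0], [1, 2], [3]]
    output_shape [4, %outer, 16, 6]
    : tensor<4x?x6xf32> into tensor<4x?x16x6xf32>
%r2 = "some_op"(%expanded) : (tensor<4x?x16x6xf32>) -> tensor<4x?x16x6xf32>
%r = tensor.collapse_shape %r2 [[0], [1, 2], [3]]
    : tensor<4x?x16x6xf32> into tensor<4x?x6xf32>
```

The static inner dimension of 16 is what lets the AMDGPU backend select mfma
intrinsics for the attention computation.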
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -72,11 +72,13 @@ iree_cc_library(
"ExtractAddressComputation.h"
"PassUtils.h"
"Passes.h"
"TensorDynamicDimAnalysis.h"
"TileSizeSelection.h"
"Transforms.h"
"UserConfig.h"
SRCS
"AddFastMathFlags.cpp"
"BlockDynamicDimensions.cpp"
"BubbleUpOrdinalOps.cpp"
"BufferizationAnalysis.cpp"
"BufferizeCopyOnlyDispatchesPass.cpp"
@@ -128,6 +130,7 @@ iree_cc_library(
"RemoveSingleIterationLoop.cpp"
"ReplaceSlowMinMaxOps.cpp"
"SplitFullPartialTransferPass.cpp"
"TensorDynamicDimAnalysis.cpp"
"TensorToVectorVectorizePad.cpp"
"TestExecutablePreprocessing.cpp"
"TestPartitionableLoopsInterface.cpp"
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -19,6 +19,12 @@ def AddFastMathFlagsPass
"given a floating-point mode.";
}

def BlockDynamicDimensionsPass
: Pass<"iree-codegen-block-dynamic-dimensions"> {
let summary = "Expand dynamic dimensions that are known to be multiples of "
"statically known values.";
}
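
As a usage sketch (this invocation follows standard MLIR pass-flag
conventions and is an assumption, not a test from this commit), the pass can
be exercised standalone through `iree-opt`:

```mlir
// RUN: iree-opt --iree-codegen-block-dynamic-dimensions %s
```

The rewrite only triggers where `TensorDynamicDimAnalysis` can prove
divisibility of a dynamic dimension from the dispatch's integer operands.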

def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
let summary = "Bubbles op ordinal ops to allow for workgroup count computation";
let description = [{
Expand Down