Skip to content

Commit

Permalink
[Codegen][GPU] Improve forall hoisting pattern for single trip loops (#18418)
Browse files Browse the repository at this point in the history

For single trip scf.forall loops, the `tensor.extract_slice` on the
output can be folded away, causing the forall loop hoisting pattern to
fail. The single trip loops themselves cannot be folded away when they
carry processor ID mappings, because they may resolve to an `scf.if`.
So this patch extends the loop hoisting pattern to support hoisting in
the case of single trip loops where the `tensor.extract_slice` has been
folded away.
  • Loading branch information
qedawkins authored Sep 7, 2024
1 parent edc5d5e commit 6c9aad0
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-gpu-fuse-and-hoist-parallel-loops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::iree_compiler::IREE::GPU {

#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
Expand Down Expand Up @@ -192,6 +196,8 @@ struct FuseTilableForallConsumers final
void FuseAndHoistParallelLoopsPass::runOnOperation() {
MLIRContext *context = &getContext();

FunctionOpInterface funcOp = getOperation();

// First run the hoisting and fusion patterns.
{
RewritePatternSet patterns(context);
Expand All @@ -200,12 +206,13 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
patterns.add<FuseForalls>(context);
patterns.add<FuseTilableForallConsumers>(context);
populateForallLoopHoistingPattern(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

LDBG("After fusing and hoisting loops\n" << funcOp);

// After hoisting parallel loops, try to fuse in any newly revealed consumers
// and destinations.
// TODO: Move the consumer fusion pattern to an explicit worklist rather than
Expand All @@ -216,24 +223,26 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
patterns.add<FuseTilableForallConsumers>(context);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

LDBG("After fusing new consumers\n" << funcOp);

// Finally try to do any new producer fusions.
{
RewritePatternSet patterns(context);
patterns.add<FuseTilableDestinationProducers>(context);
patterns.add<FuseTilableSliceProducers>(context);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

LDBG("After fusing new producers\n" << funcOp);
}

} // namespace mlir::iree_compiler::IREE::GPU
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,39 @@ func.func @multi_hoist_with_other_ops_in_loop(%2: tensor<128x128xf16>, %3: tenso
// CHECK: scf.forall.in_parallel
// CHECK: scf.forall.in_parallel
// CHECK: return

// -----

// Regression test for hoisting an scf.forall nest out of an scf.for when the
// foralls are single trip (upper bounds all 1). In that case canonicalization
// can fold away the `tensor.extract_slice` on the shared output, so the
// hoisting pattern must succeed without a paired extract/insert slice, as long
// as the `tensor.parallel_insert_slice` overwrites the full destination
// (here: a full 128x128 insert into a 128x128 tensor).
func.func @hoist_with_single_trip_loops(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
  %c4 = arith.constant 4 : index
  %c128 = arith.constant 128 : index
  %c0 = arith.constant 0 : index
  %empty = tensor.empty() : tensor<128x128xf16>
  // Sequential loop wrapping two nested single-trip foralls (warp over thread
  // mapping); both foralls should be hoisted above the scf.for.
  %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
    %9 = scf.forall (%arg2, %arg3) in (1, 1) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
      // Outer forall still carries an extract_slice on its shared out.
      %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3] [128, 128] [1, 1] : tensor<128x128xf16> to tensor<128x128xf16>
      // Inner forall has NO extract_slice — it reads the iter arg %arg7
      // directly, exercising the folded-slice path of the hoisting pattern.
      %10 = scf.forall (%arg5, %arg6) in (1, 1) shared_outs(%arg7 = %extracted_slice) -> (tensor<128x128xf16>) {
        %16 = linalg.copy ins(%arg7 : tensor<128x128xf16>) outs(%2 : tensor<128x128xf16>) -> tensor<128x128xf16>
        scf.forall.in_parallel {
          // Full-destination overwrite: sizes [128, 128] cover all of %arg7.
          tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [128, 128] [1, 1] : tensor<128x128xf16> into tensor<128x128xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [128, 128] [1, 1] : tensor<128x128xf16> into tensor<128x128xf16>
      }
    } {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
    scf.yield %9 : tensor<128x128xf16>
  }
  return %8 : tensor<128x128xf16>
}

// Expected structure after the pass: both foralls hoisted outside, with the
// scf.for (and the copy) now innermost.
// CHECK-LABEL: func @hoist_with_single_trip_loops
//  CHECK-SAME:   %[[I0:[A-Za-z0-9]+]]: tensor<128x128xf16>
//  CHECK-SAME:   %[[I1:[A-Za-z0-9]+]]: tensor<128x128xf16>
//       CHECK:   scf.forall
//       CHECK:     scf.forall
//       CHECK:       %[[LOOP:.+]] = scf.for {{.*}} -> (tensor<128x128xf16>)
//       CHECK:         linalg.copy
//       CHECK:     scf.forall.in_parallel
//       CHECK:   scf.forall.in_parallel
//       CHECK:   return
87 changes: 68 additions & 19 deletions compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
rewriter.moveOpBefore(op, &forallBody->getOperations().front());
}

bool isSingleTripLoop = forallOp.isNormalized() &&
llvm::all_of(forallOp.getStaticUpperBound(),
[](int64_t i) { return i == 1; });

// Step 2. Collect the set of tensor.parallel_insert_slice ops in the
// terminator and their paired extract_slice ops from the for loop iter arg.
SmallVector<Operation *> sliceOperandProducers;
Expand All @@ -1106,7 +1110,8 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
scf::InParallelOp parallelTerminator = forallOp.getTerminator();
SmallVector<tensor::ParallelInsertSliceOp> terminators(
forallOp.getNumResults());
SmallVector<tensor::ExtractSliceOp> pairedSlices(forallOp.getNumResults());
SmallVector<std::optional<tensor::ExtractSliceOp>> pairedSlices(
forallOp.getNumResults(), std::nullopt);
int64_t numInductionVars = forallOp.getInductionVars().size();
for (auto &yieldingOp : parallelTerminator.getYieldingOps()) {
auto parallelInsert = cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
Expand All @@ -1117,28 +1122,58 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
if (user == parallelInsert)
continue;
auto maybeSlice = dyn_cast<tensor::ExtractSliceOp>(user);
// Fail if the destination has more users than a direct insert and
// extract slice.
if (!maybeSlice) {
return failure();
// Fail if the destination has more users than a direct insert and
// extract slice unless it is a single trip loop.
if (!isSingleTripLoop) {
return failure();
}
continue;
}
// Require a single extract per destination.
// Require at most one extract per destination.
if (destSlice) {
return failure();
}
destSlice = maybeSlice;
}

// Verify they operate on equivalent subsets, ensuring the slices are
// hoistable. It is still possible to hoist the loop if this is not true,
// however in such cases we likely formed the loops in the wrong order.
if (!cast<SubsetOpInterface>(*destSlice)
.operatesOnEquivalentSubset(
cast<SubsetOpInterface>(*parallelInsert),
[](Value v1, Value v2) { return v1 == v2; })) {
if (destSlice && !cast<SubsetOpInterface>(*destSlice)
.operatesOnEquivalentSubset(
cast<SubsetOpInterface>(*parallelInsert),
[](Value v1, Value v2) { return v1 == v2; })) {
return failure();
}
terminators[destBbArg.getArgNumber() - numInductionVars] = parallelInsert;
pairedSlices[destBbArg.getArgNumber() - numInductionVars] = destSlice;

auto isOverwritingFullDestination =
[](tensor::ParallelInsertSliceOp insert) {
// TODO: Handle rank reducing case.
if (insert.getSourceType().getRank() !=
insert.getDestType().getRank()) {
return false;
}
for (auto [dim, size] : llvm::enumerate(insert.getMixedSizes())) {
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
{size}, {insert.getDest(), static_cast<int64_t>(dim)});
if (failed(equalDimSize) || !*equalDimSize)
return false;
}
return true;
};

// For single trip loops, verify that the parallel_insert_slice is
// overwriting the full destination.
if (!destSlice && !isOverwritingFullDestination(parallelInsert)) {
return failure();
}

int64_t argId = destBbArg.getArgNumber() - numInductionVars;
terminators[argId] = parallelInsert;
if (destSlice) {
pairedSlices[argId] = destSlice;
}

// Collect all of the offset/size/stride operands for both slices and
// compute a backwards slice of the program from them. Fail if any of
Expand All @@ -1148,10 +1183,12 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
parallelInsert.getOperands().begin() +
parallelInsert.getOffsetSizeAndStrideStartOperandIndex(),
parallelInsert.getOperands().end());
sliceOperands.insert(
destSlice.getOperands().begin() +
destSlice.getOffsetSizeAndStrideStartOperandIndex(),
destSlice.getOperands().end());
if (destSlice) {
sliceOperands.insert(
destSlice.getOperands().begin() +
destSlice.getOffsetSizeAndStrideStartOperandIndex(),
destSlice.getOperands().end());
}
for (Value operand : sliceOperands) {
if (auto bbArg = dyn_cast<BlockArgument>(operand)) {
if (bbArg.getOwner()->getParentOp() == loop) {
Expand Down Expand Up @@ -1200,8 +1237,15 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(newForallOp.getTerminator());
SmallVector<Value> newInits;
for (auto slice : pairedSlices) {
newInits.push_back(slice.getResult());
for (auto [iterArgId, slice] : llvm::enumerate(pairedSlices)) {
if (slice) {
newInits.push_back(slice.value().getResult());
continue;
}

// If there is no paired slice (for a single trip count loop) then
// use the iter arg of the forall op directly.
newInits.push_back(newForallOp.getRegionIterArgs()[iterArgId]);
}
// Step 4. Create a new for loop with new inits for the result of the
// extracted slices.
Expand All @@ -1224,7 +1268,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
// args.
for (auto [hoistedSlice, iterArg] :
llvm::zip_equal(pairedSlices, newLoop.getRegionIterArgs())) {
rewriter.replaceAllUsesExcept(hoistedSlice, iterArg, newLoop);
if (hoistedSlice) {
rewriter.replaceAllUsesExcept(hoistedSlice.value(), iterArg,
newLoop);
}
}

// Create the terminator for the new loop using the sources of the
Expand All @@ -1243,7 +1290,9 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
rewriter.moveOpBefore(sliceOperandProducer, newLoop);
}
for (auto slice : pairedSlices) {
rewriter.moveOpBefore(slice, newLoop);
if (slice) {
rewriter.moveOpBefore(slice.value(), newLoop);
}
}

// Create the new terminator for the hoisted forall loop using the results
Expand Down

0 comments on commit 6c9aad0

Please sign in to comment.