Skip to content

Commit

Permalink
Add pattern to undo in/out perm
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 committed Feb 6, 2025
1 parent 9d5cb74 commit e7473ff
Showing 1 changed file with 52 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -83,6 +85,55 @@ struct FoldFullInsertSlice : public OpRewritePattern<tensor::InsertSliceOp> {
}
};

//===----------------------------------------------------------------------===//
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//

// If possible, interchange indexing maps to make input maps to remove
// permutations.
struct ElementwiseOpInterchangePattern final
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
genericOp.getNumDpsInputs() == 0)
return failure();

// Find an input map that is a non-identity permutation that matches the
// output map (also not identity).
AffineMap inputMap;
auto *initOperand = genericOp.getDpsInitOperand(0);
AffineMap outputMap = genericOp.getMatchingIndexingMap(initOperand);
auto inputMaps = genericOp.getIndexingMapsArray();
for (auto candidateInputMap : ArrayRef(inputMaps).drop_back()) {
if (!candidateInputMap.isIdentity() &&
candidateInputMap.isPermutation() && candidateInputMap == outputMap) {
inputMap = candidateInputMap;
break;
}
}
if (!inputMap) {
return failure();
}

ArrayRef<AffineExpr> exprs = inputMap.getResults();
auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});

// Don't let this mess up other maps.
for (auto map : genericOp.getIndexingMapsArray()) {
AffineMap composed = map.compose(inversePermutation(inputMap));
if (inputMap != map && !compressUnusedDims(composed).isIdentity()) {
return failure();
}
}

return linalg::interchangeGenericOp(rewriter, genericOp, perm);
}
};

/// Canonicalize operations in nested regions.
struct CanonicalizerPass
: public impl::CanonicalizerPassBase<CanonicalizerPass> {
Expand All @@ -105,6 +156,7 @@ struct CanonicalizerPass
// compilation phase.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
owningPatterns.add<FoldFullInsertSlice>(context);
owningPatterns.add<ElementwiseOpInterchangePattern>(context);

patterns =
std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
Expand Down

0 comments on commit e7473ff

Please sign in to comment.