
[mlir][Vector] Remove usage of vector.insertelement/extractelement from Vector #144413
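
For context: this patch mechanically replaces the deprecated scalar element ops with their generic counterparts, `vector.extractelement` with `vector.extract` and `vector.insertelement` with `vector.insert`. The sketch below illustrates the mapping on a 1-D vector; the value names are made up, and the syntax mirrors the forms used in the updated tests.

```mlir
// Old form (being removed): dedicated scalar element ops.
%e0 = vector.extractelement %v[%i : index] : vector<4xf32>
%v0 = vector.insertelement %s, %v[%i : index] : vector<4xf32>

// New form: vector.extract / vector.insert with a dynamic position.
%e1 = vector.extract %v[%i] : f32 from vector<4xf32>
%v1 = vector.insert %s, %v[%i] : f32 into vector<4xf32>
```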

Open · wants to merge 1 commit into main
@@ -233,8 +233,8 @@ void populateBreakDownVectorReductionPatterns(
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
/// For such cases, we can rewrite it to ExtractOp + lower rank
/// ExtractStridedSliceOp + InsertOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
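
As a rough illustration of the decomposition described in the comment above (not taken from this patch; shapes and value names are invented), an n-D strided slice is peeled one leading dimension at a time:

```mlir
// n-D input: take a 1x2 slice out of a 2x4 vector.
%slice = vector.extract_strided_slice %src
           {offsets = [1, 0], sizes = [1, 2], strides = [1, 1]}
           : vector<2x4xf32> to vector<1x2xf32>

// Decomposed sketch: extract the leading dimension, run a rank-1
// extract_strided_slice, and insert the result into the destination
// vector %acc that the pattern materializes.
%row  = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
%part = vector.extract_strided_slice %row
          {offsets = [0], sizes = [2], strides = [1]}
          : vector<4xf32> to vector<2xf32>
%res  = vector.insert %part, %acc[0] : vector<2xf32> into vector<1x2xf32>
```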

13 changes: 5 additions & 8 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();

Location loc = xferOp.getLoc();
return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
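
In IR terms, the mask check now reads the lane's bit with a dynamically indexed `vector.extract`; a minimal sketch with assumed names and types:

```mlir
// One mask bit per loop iteration, indexed by the induction variable %iv.
%mask_bit = vector.extract %mask[%iv] : i1 from vector<[16]xi1>
```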

/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -760,8 +760,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {

if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
// non-constant value.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -824,8 +823,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}

// Print the scalar elements in the inner most loop.
auto element =
rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
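
Taken together, the print lowering flattens the vector and pulls each scalar out with a dynamic index, roughly as below (shapes are placeholders; the exact IR is pinned down by the updated vector-to-scf tests):

```mlir
%flat = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
%el   = vector.extract %flat[%flat_idx] : f32 from vector<4xf32>
vector.print %el : f32 punctuation <no_punctuation>
```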

@@ -1567,7 +1565,7 @@ struct Strategy1d<TransferReadOp> {
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1595,8 +1593,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
auto val =
b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
});
b.create<scf::YieldOp>(loc);
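
On the 1-D transfer path, each in-bounds iteration now goes through `vector.insert` / `vector.extract` with the loop induction variable as the position. A hedged sketch of the per-element IR (names and shapes are placeholders):

```mlir
// Read side: load a scalar and insert it at the dynamic lane %iv.
%el  = memref.load %mem[%idx] : memref<?xf32>
%acc = vector.insert %el, %vec[%iv] : f32 into vector<4xf32>

// Write side: extract the lane and store it back.
%out = vector.extract %vec[%iv] : f32 from vector<4xf32>
memref.store %out, %mem[%idx] : memref<?xf32>
```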
51 changes: 5 additions & 46 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1255,27 +1255,6 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to convert vector.extractelement to vector.extract.
struct WarpOpExtractElement : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
if (!operand)
return failure();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
SmallVector<OpFoldResult> indices;
if (auto pos = extractOp.getPosition()) {
indices.push_back(pos);
}
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
extractOp, extractOp.getVector(), indices);
return success();
}
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
struct WarpOpInsertScalar : public WarpDistributionPattern {
@@ -1483,26 +1462,6 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};

struct WarpOpInsertElement : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
if (!operand)
return failure();
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
SmallVector<OpFoldResult> indices;
if (auto pos = insertOp.getPosition()) {
indices.push_back(pos);
}
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertOp>(
insertOp, insertOp.getSource(), insertOp.getDest(), indices);
return success();
}
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
@@ -1761,11 +1720,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
patterns.getContext(), benefit);
patterns
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
119 changes: 25 additions & 94 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extract_element`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
: public OpRewritePattern<VectorExtractOp> {
using Base = OpRewritePattern<VectorExtractOp>;

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be `vector.extract` ops. If
/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
/// users. Otherwise, rewrite only if the extract op is the single user of the
/// transfer op. Rewriting a single vector load with multiple scalar loads may
/// negatively affect performance.
class RewriteScalarExtractOfTransferRead
: public OpRewritePattern<vector::ExtractOp> {
public:
RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
PatternBenefit benefit,
bool allowMultipleUses)
: Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

LogicalResult match(VectorExtractOp extractOp) const {
auto xferOp =
extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
RewriteScalarExtractOfTransferRead(MLIRContext *context,
PatternBenefit benefit,
bool allowMultipleUses)
: OpRewritePattern(context, benefit),
allowMultipleUses(allowMultipleUses) {}

LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Match phase.
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
// Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
// If multiple uses are allowed, check if all the xfer uses are extract ops.
if (allowMultipleUses &&
!llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
return isa<vector::ExtractOp, vector::ExtractElementOp>(
use.getOwner());
return isa<vector::ExtractOp>(use.getOwner());
}))
return failure();
// Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
// Cannot rewrite if the indices may be out of bounds.
if (xferOp.hasOutOfBoundsDim())
return failure();
return success();
}

private:
bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
: public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
using RewriteScalarExtractOfTransferReadBase::
RewriteScalarExtractOfTransferReadBase;

LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
PatternRewriter &rewriter) const override {
if (failed(match(extractOp)))
return failure();

// Construct scalar load.
auto loc = extractOp.getLoc();
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
if (extractOp.getPosition()) {
AffineExpr sym0, sym1;
bindSymbols(extractOp.getContext(), sym0, sym1);
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, sym0 + sym1,
{newIndices[newIndices.size() - 1], extractOp.getPosition()});
if (auto value = dyn_cast<Value>(ofr)) {
newIndices[newIndices.size() - 1] = value;
} else {
newIndices[newIndices.size() - 1] =
rewriter.create<arith::ConstantIndexOp>(loc,
*getConstantIntValue(ofr));
}
}
if (isa<MemRefType>(xferOp.getBase().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
newIndices);
} else {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, xferOp.getBase(), newIndices);
}

return success();
}
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
: public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
using RewriteScalarExtractOfTransferReadBase::
RewriteScalarExtractOfTransferReadBase;

LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
if (failed(match(extractOp)))
return failure();

// Construct scalar load.
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
// Rewrite phase: construct scalar load.
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead

return success();
}

private:
bool allowMultipleUses;
};
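
To make the before/after concrete, here is a hedged example of the folding this pattern performs on a memref-backed transfer_read with a single extract user (indices and shapes are invented for illustration):

```mlir
// Before: read a whole vector, then extract one scalar at %k.
%v = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true]}
       : memref<?x?xf32>, vector<8xf32>
%s = vector.extract %v[%k] : f32 from vector<8xf32>

// After: a single scalar load at the folded index %j + %k.
%idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%j, %k]
%s2  = memref.load %m[%i, %idx] : memref<?x?xf32>
```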

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
bool allowMultipleUses) {
patterns.add<RewriteScalarExtractElementOfTransferRead,
RewriteScalarExtractOfTransferRead>(patterns.getContext(),
patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
benefit, allowMultipleUses);
patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
14 changes: 7 additions & 7 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
// Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
// CHECK: scf.if
// CHECK-NEXT: memref.load
// CHECK-NEXT: vector.insertelement
// CHECK-NEXT: vector.insert
// CHECK-NEXT: scf.yield
// CHECK-NEXT: else
// CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK: %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
// CHECK: scf.if {{.*}} -> (vector<3xf32>) {
// CHECK-NEXT: %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
// CHECK-NEXT: %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
// CHECK: scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
// CHECK: %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
// CHECK: scf.if %[[MASK_VAL]] {
// CHECK: %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
// CHECK: %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
// CHECK: memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
// CHECK: } else {
// CHECK: }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
// CHECK: %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[FLAT_INDEX]]] : f32 from vector<4xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
// CHECK: scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
// CHECK: %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
// CHECK: %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][%[[IDX]]] : i32 from vector<[4]xi32>
// CHECK: vector.print %[[EL]] : i32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
@@ -8,7 +8,7 @@
func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
%1 = vector.extractelement %0[] : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
return %1 : f32
}

@@ -24,7 +24,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
%cst = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
%1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
%1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
return %1 : f32
}

@@ -37,7 +37,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
%1 = vector.extractelement %0[] : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
return %1 : f32
}
