[mlir][Vector] Remove usage of vector.insertelement/extractelement from Vector #144413
Conversation
This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops. RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops. It removes instances of `vector.extractelement` and `vector.insertelement` from the Vector dialect layer.
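For context, a minimal hand-written sketch of the syntax change driving these rewrites (not taken from the patch; %v, %s, and %idx are placeholder values):

// Old form (being removed): the dynamic index is an SSA operand with an explicit type.
%e0 = vector.extractelement %v[%idx : index] : vector<16xf32>
%v0 = vector.insertelement %s, %v[%idx : index] : vector<16xf32>

// New form: vector.extract and vector.insert accept dynamic positions directly.
%e1 = vector.extract %v[%idx] : f32 from vector<16xf32>
%v1 = vector.insert %s, %v[%idx] : f32 into vector<16xf32>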
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes: This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement`. It removes instances of these ops from the Vector dialect layer.

Patch is 29.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144413.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..ec0f856cb3f5a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -233,8 +233,8 @@ void populateBreakDownVectorReductionPatterns(
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..45059f19a95c4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+ return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -760,8 +760,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
- // non-constant value (which can currently only be done via
- // vector.extractelement for 1D vectors).
+ // non-constant value.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -824,8 +823,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}
// Print the scalar elements in the inner most loop.
- auto element =
- rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+ auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
@@ -1567,7 +1565,7 @@ struct Strategy1d<TransferReadOp> {
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
- return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+ return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1595,8 +1593,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
});
b.create<scf::YieldOp>(loc);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..90970ae53defc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1255,27 +1255,6 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
-/// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement : public WarpDistributionPattern {
- using Base::Base;
- LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
- if (!operand)
- return failure();
- auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
- SmallVector<OpFoldResult> indices;
- if (auto pos = extractOp.getPosition()) {
- indices.push_back(pos);
- }
- rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getVector(), indices);
- return success();
- }
-};
-
/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
struct WarpOpInsertScalar : public WarpDistributionPattern {
@@ -1483,26 +1462,6 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
-struct WarpOpInsertElement : public WarpDistributionPattern {
- using Base::Base;
- LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
- if (!operand)
- return failure();
- auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
- SmallVector<OpFoldResult> indices;
- if (auto pos = insertOp.getPosition()) {
- indices.push_back(pos);
- }
- rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(), indices);
- return success();
- }
-};
-
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
@@ -1761,11 +1720,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..62e7f7cc61f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
unsigned targetVectorBitwidth;
};
-/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
-/// to `memref.load` patterns. The `match` method is shared for both
-/// `vector.extract` and `vector.extract_element`.
-template <class VectorExtractOp>
-class RewriteScalarExtractOfTransferReadBase
- : public OpRewritePattern<VectorExtractOp> {
- using Base = OpRewritePattern<VectorExtractOp>;
-
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be `vector.extract` ops. If
+/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
+/// users. Otherwise, rewrite only if the extract op is the single user of the
+/// transfer op. Rewriting a single vector load with multiple scalar loads may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+ : public OpRewritePattern<vector::ExtractOp> {
public:
- RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
- PatternBenefit benefit,
- bool allowMultipleUses)
- : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
-
- LogicalResult match(VectorExtractOp extractOp) const {
- auto xferOp =
- extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
+ RewriteScalarExtractOfTransferRead(MLIRContext *context,
+ PatternBenefit benefit,
+ bool allowMultipleUses)
+ : OpRewritePattern(context, benefit),
+ allowMultipleUses(allowMultipleUses) {}
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Match phase.
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
// Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
// If multiple uses are allowed, check if all the xfer uses are extract ops.
if (allowMultipleUses &&
!llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
- return isa<vector::ExtractOp, vector::ExtractElementOp>(
- use.getOwner());
+ return isa<vector::ExtractOp>(use.getOwner());
}))
return failure();
// Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
// Cannot rewrite if the indices may be out of bounds.
if (xferOp.hasOutOfBoundsDim())
return failure();
- return success();
- }
-
-private:
- bool allowMultipleUses;
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
- using RewriteScalarExtractOfTransferReadBase::
- RewriteScalarExtractOfTransferReadBase;
-
- LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
- PatternRewriter &rewriter) const override {
- if (failed(match(extractOp)))
- return failure();
-
- // Construct scalar load.
- auto loc = extractOp.getLoc();
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
- SmallVector<Value> newIndices(xferOp.getIndices().begin(),
- xferOp.getIndices().end());
- if (extractOp.getPosition()) {
- AffineExpr sym0, sym1;
- bindSymbols(extractOp.getContext(), sym0, sym1);
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, sym0 + sym1,
- {newIndices[newIndices.size() - 1], extractOp.getPosition()});
- if (auto value = dyn_cast<Value>(ofr)) {
- newIndices[newIndices.size() - 1] = value;
- } else {
- newIndices[newIndices.size() - 1] =
- rewriter.create<arith::ConstantIndexOp>(loc,
- *getConstantIntValue(ofr));
- }
- }
- if (isa<MemRefType>(xferOp.getBase().getType())) {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
- newIndices);
- } else {
- rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
- extractOp, xferOp.getBase(), newIndices);
- }
-
- return success();
- }
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractOfTransferRead
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
- using RewriteScalarExtractOfTransferReadBase::
- RewriteScalarExtractOfTransferReadBase;
-
- LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- if (failed(match(extractOp)))
- return failure();
- // Construct scalar load.
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+ // Rewrite phase: construct scalar load.
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead
return success();
}
+
+private:
+ bool allowMultipleUses;
};
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
bool allowMultipleUses) {
- patterns.add<RewriteScalarExtractElementOfTransferRead,
- RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+ patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
benefit, allowMultipleUses);
patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a..33177736eb5fe 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
// Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
// CHECK: scf.if
// CHECK-NEXT: memref.load
- // CHECK-NEXT: vector.insertelement
+ // CHECK-NEXT: vector.insert
// CHECK-NEXT: scf.yield
// CHECK-NEXT: else
// CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK: %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
// CHECK: scf.if {{.*}} -> (vector<3xf32>) {
// CHECK-NEXT: %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
- // CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
+ // CHECK-NEXT: %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
// CHECK: scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
-// CHECK: %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
+// CHECK: %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
// CHECK: scf.if %[[MASK_VAL]] {
-// CHECK: %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
+// CHECK: %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
// CHECK: memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
// CHECK: } else {
// CHECK: }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
// CHECK: %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
-// CHECK: %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[FLAT_INDEX]]] : f32 from vector<4xf32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
// CHECK: scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
// CHECK: %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
+// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][%[[IDX]]] : i32 from vector<[4]xi32>
// CHECK: vector.print %[[EL]] : i32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7a1d6b3a8344a..7fec1c6ba5642 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -8,7 +8,7 @@
func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
- %1 = vector.extractelement %0[] : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
return %1 : f32
}
@@ -24,7 +24,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
%cst = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
- %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+ %1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
return %1 : f32
}
@@ -37,7 +37,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %t[%idx, %idx, %idx], %...
[truncated]
This is long overdue, thank you!
LGTM % test-name fixes :)
@@ -1142,7 +1142,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 // CHECK-PROP:   #[[$MAP:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
 // CHECK-PROP:   #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
-// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+// CHECK-PROP-LABEL: func @_1d(
FIXME :)
@@ -1155,48 +1155,48 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 // CHECK-PROP:   scf.yield %[[W]]#0
 // CHECK-PROP: }
 // CHECK-PROP: return %[[R]]
-func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+func.func @_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
FIXME :)
   gpu.yield %1 : vector<96xf32>
 }
 return %r : vector<3xf32>
 }

 // -----

-// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+// CHECK-PROP-LABEL: func @_1d_broadcast(
FIXME :)
   gpu.yield %1 : vector<96xf32>
 }
 return %r : vector<96xf32>
 }

 // -----

-// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+// CHECK-PROP-LABEL: func @_0d(
FIXME :)
This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops.

RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

It removes instances of `vector.extractelement` and `vector.insertelement` from the Vector dialect layer.
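To illustrate the kind of rewrite the updated RewriteScalarExtractOfTransferRead pattern performs, here is a sketch adapted from the tests above (the function name is changed for illustration, and the lowered form is paraphrased from the pattern's documentation rather than copied from the CHECK lines):

// Input: a 1-D transfer_read whose only user extracts a scalar at a dynamic position.
func.func @extract_of_transfer_read(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
  %1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
  return %1 : f32
}

// After the rewrite, the vector load and the extract fold into a single scalar
// load at the combined position, roughly:
//   %pos = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%idx, %idx2]
//   %1 = memref.load %m[%idx, %idx, %pos] : memref<?x?x?xf32>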