-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] [Vector] Linearization patterns for vector.load and vector.store #145115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@newling following up on #143420 (comment) For 2, |
Yeah, this is what I meant 👍 Ideally we'll eventually have something like flattening of transfer_read, done here. i.e. linearize even when there is more than 1 dimension of size > 1, and it needn't be the inner-most dim. But I guess that can wait. FWIW IMO that transfer_read code should be in VectorLinearize too, I mentioned that at the bottom of this comment. And the vector.load linearization code could probably then reuse some of it. Something for the future, maybe! |
@newling can you review it as well? |
@llvm/pr-subscribers-mlir-vector Author: Nishant Patel (nbpatel) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145115.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
+ }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// vector.store %arg0, %arg1[%c0, %%c0]
+/// : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = storeOp.getValueToStore().getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+
+ Value valueToStore = adaptor.getValueToStore();
+ if (valueToStore.getType() != linearTy) {
+ valueToStore = rewriter.create<vector::ShapeCastOp>(
+ storeOp.getLoc(), linearTy, valueToStore);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+ LinearizeVectorStore>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[CAST]] : vector<1x4xf32>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return
+}
|
@llvm/pr-subscribers-mlir Author: Nishant Patel (nbpatel) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145115.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
+ }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// vector.store %arg0, %arg1[%c0, %%c0]
+/// : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = storeOp.getValueToStore().getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+
+ Value valueToStore = adaptor.getValueToStore();
+ if (valueToStore.getType() != linearTy) {
+ valueToStore = rewriter.create<vector::ShapeCastOp>(
+ storeOp.getLoc(), linearTy, valueToStore);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+ LinearizeVectorStore>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[CAST]] : vector<1x4xf32>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I added 2 small comments about simplifying and generalizing a bit.
Addressed the feedback, thanks :) |
This PR add inearizarion pattern for vector.load and vector.store. It is follow up PR to #143420 (comment)