[mlir][tensor] Introduce FoldTensorCastUnPackOp
#121393
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change mirrors
a similar update made for `tensor::PackOp` in #114559. Below is the updated
rationale for `tensor::UnPackOp`.

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>

%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

as:

```mlir
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

This leads to an Op verification failure because the folder does not update
the inner tile sizes in the unpack Op. This patch resolves the issue.

Additional Changes:
* invalid.mlir: Fixes a typo.
* TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra comments
  following this discussion: #115772.
Full diff: https://github.com/llvm/llvm-project/pull/121393.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..ee9a4012a01393 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+ .getShape()
+ .take_back(op.getMixedTiles().size()),
+ op.getMixedTiles())) {
+ int64_t shape = std::get<0>(it);
+ // If the current source shape is dynamic, just preserve this mixed
+ // size.
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current source is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ if (Attribute attr =
+ llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+ // Already a constant
+ newMixedTileSizes.push_back(std::get<1>(it));
+ } else {
+ assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4969,7 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
// Reject tensor::PackOp - there's dedicated pattern for that instead.
- if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+ if (!foldTensorCastPrecondition(op) || isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5002,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..88e3691e2d6297 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
+
// -----
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
// -----
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK: return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+ %src: tensor<1x1x8x1xi32>,
+ %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+ %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+ %c8 = arith.constant 8 : index
+ %unpack = tensor.unpack %cast
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%c8, 1]
+ into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+ return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
// -----
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
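For context on how the newly registered pattern gets exercised: the dialect-level hook updated above is what canonicalization pulls from. Below is a minimal sketch of driving those patterns directly from C++; it uses standard MLIR APIs and is illustrative only, not part of this patch (it assumes the tensor dialect is already loaded).

```cpp
// Illustrative only: run the tensor dialect's canonicalization patterns
// (now including FoldTensorCastUnPackOp) over a module.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

LogicalResult runTensorCanonicalization(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  // The hook modified in this patch; it adds FoldTensorCastPackOp,
  // FoldTensorCastUnPackOp, and FoldTensorCastProducerOp.
  ctx->getLoadedDialect<tensor::TensorDialect>()
      ->getCanonicalizationPatterns(patterns);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```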
✅ With the latest revision this PR passed the C/C++ code formatter.
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change mirrors
a similar update made for `tensor::PackOp` in llvm#114559. Below is the
updated rationale for `tensor::UnPackOp`.

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>

%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

as:

```mlir
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

This leads to an Op verification failure because the folder does not update
the inner tile sizes in the unpack Op. This patch resolves the issue.

Additional Changes:
* invalid.mlir: Fixes a typo.
* TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra comments
  following this discussion: llvm#115772.
Force-pushed from 1dcc023 to 1bc2d8e.
LGTM, just some nits. Thanks for all the refactoring so far!
if (Attribute attr =
        llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
nit: Can you add a local variable (similar to what you did above with `shape`) for the second iterator's value (e.g., something like `tile`)? I think it makes it more clear what the iterator is when reading the code.
Good point! And to keep things consistent, let me update `FoldTensorCastPackOp` as well.
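A possible shape of that cleanup (a sketch of the reviewer's suggestion, not necessarily the exact code that landed):

```cpp
// Sketch: name both halves of the zipped pair up front so the loop body
// reads clearly; `tile` is the suggested local for the mixed tile size.
for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
                             .getShape()
                             .take_back(op.getMixedTiles().size()),
                         op.getMixedTiles())) {
  int64_t shape = std::get<0>(it);
  OpFoldResult tile = std::get<1>(it);
  // ... use `shape` and `tile` instead of repeated std::get<> calls ...
}
```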
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;
nit: Most of the logic in this function is the same as for tensor.pack, but with the source type instead of the dest type. Could you refactor the logic a bit to try to share code from a single function (mainly for finding the new mixed tile sizes)?
Great point, sending update shortly. Thanks for the suggestion!
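One way such a shared helper could look, assembled from the two patterns in the diff above. This is a hedged sketch: the helper name `getNewMixedTileSizes` and its exact signature are assumptions for illustration, not necessarily what was committed.

```cpp
// Hypothetical shared helper: compute updated mixed tile sizes for a
// pack/unpack op whose packed-side type may have become more static after
// folding a tensor.cast.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     ArrayRef<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    OpFoldResult tile = std::get<1>(it);
    // Dynamic dim in the packed type: preserve the original mixed size.
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(tile);
      continue;
    }
    // Static dim: keep an existing constant attribute; otherwise replace the
    // dynamic SSA tile size with a static index attribute.
    if (llvm::dyn_cast_if_present<Attribute>(tile)) {
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }
  return newMixedTileSizes;
}
```

With a helper of this shape, both `FoldTensorCastPackOp` and `FoldTensorCastUnPackOp` would only differ in whether they pass the dest type (pack) or the source type (unpack).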
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
nit: Maybe use the same style as the test below for the attribute? (i.e., `test_attr` instead of `some_attr`)
I've actually followed the test above:

llvm-project/mlir/test/Dialect/Tensor/canonicalize.mlir
Lines 2791 to 2813 in 45e874e

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
// CHECK-SAME: some_attr
// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
    %dest: tensor<1x1x8x1xi32>,
    %src: tensor<7x?xi32>,
    %pad: i32) -> tensor<1x1x8x1xi32> {
  %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
  %c8 = arith.constant 8 : index
  %pack = tensor.pack %src padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
  %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
  return %res : tensor<1x1x8x1xi32>
}
😂 Let me unify this.
Address PR comments
Adds an end-to-end test for `tensor.unpack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required a few
fixes in handling `tensor.unpack` (and similar fixes for `tensor.pack` before
that):

* llvm#119379, llvm#121393, llvm#121400.

The end goal for this test is to incrementally increase its complexity and to
work towards scalable tile sizes. Note, this PR complements llvm#115698, in
which a similar test for `tensor.pack` was added.
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This mirrors a
similar update made for `tensor::PackOp` in #114559. Below is the updated
rationale tailored to `tensor::UnPackOp`.

ISSUE DESCRIPTION

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>

%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

as:

```mlir
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

This triggers an Op verification failure because the folder does not
update the inner tile sizes in the unpack Op. This patch addresses the
issue by ensuring proper handling of inner tile sizes.

ADDITIONAL CHANGES
* Removes unnecessary `(void)tileSize`.
* Updates `FoldTensorCastPackOp` for consistency with the newly introduced
  `FoldTensorCastUnPackOp`.
* Unifies the test attribute to `test_attr` (e.g., replaced mixed use of
  `test_attr` and `some_attr`).