
[mlir][tensor] Introduce FoldTensorCastUnPackOp #121393

Merged

Conversation

banach-space (Contributor) commented Dec 31, 2024

This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This mirrors a
similar update made for `tensor::PackOp` in #114559. Below is the updated
rationale tailored to `tensor::UnPackOp`.

ISSUE DESCRIPTION

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
// Note: `%c8` and `?`.
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

as:

```mlir
// Note: `%c8` and `8`.
%unpack = tensor.unpack %dest
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```

This triggers an Op verification failure because the folder does not
update the inner tile sizes in the unpack Op. This patch addresses the
issue by ensuring proper handling of inner tile sizes.
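
For reference, the correct fold (as produced by the new pattern and verified by the test added in this patch) both folds the cast away and materializes the now-constant inner tile size:

```mlir
// Note: `%c8` becomes `8`, matching the static source type.
%unpack = tensor.unpack %dest
  inner_dims_pos = [0, 1]
  inner_tiles = [8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```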

ADDITIONAL CHANGES

  • invalid.mlir: Fixed a typo.
  • TensorOps.cpp: Removed the unnecessary `(void)tileSize` and added extra
    comments following the discussion in #115772.
  • Tensor/canonicalize.mlir: Ensured consistent usage of test_attr (e.g.,
    replaced mixed use of test_attr and some_attr).

llvmbot (Member) commented Dec 31, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/121393.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+84-4)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+21)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+1-1)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..ee9a4012a01393 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         // Already a constant
         newMixedTileSizes.push_back(std::get<1>(it));
       } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
+        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+               "tile size and dim size don't match!");
         newMixedTileSizes.push_back(
             (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
       }
     }
 
     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      // If the current source shape is dynamic, just preserve this mixed
+      // size.
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      // If the current source is static, update the dynamic mixed-size
+      // (provided the original value is dynamic).
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+               "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4890,7 +4969,7 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {
 
     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) || isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5002,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18f..88e3691e2d6297 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
 // -----
 
 // CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+  %src: tensor<1x1x8x1xi32>,
+  %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+    %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %unpack = tensor.unpack %cast
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%c8, 1]
+      into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+    return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @pack_dont_drop_attributes(
 // CHECK: tensor.pack {{.*}}  {test_attr}
 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..1de3e281bc462b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
 
 // -----
 
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
   // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
   %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>


github-actions bot commented Dec 31, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by
introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change
mirrors a similar update made for `tensor::PackOp` in llvm#114559. Below is
the updated rationale for `tensor::UnPackOp`.

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```
as:

```mlir
%unpack = tensor.unpack %dest
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```

This leads to an Op verification failure because the folder does not
update the inner tile sizes in the unpack Op. This patch resolves the
issue.

Additional Changes:
* invalid.mlir: Fixes a typo.
* TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra
  comments following this discussion:
  llvm#115772.
banach-space force-pushed the andrzej/add_fold_cast_into_unpack branch from 1dcc023 to 1bc2d8e on December 31, 2024 at 14:38
Max191 (Contributor) left a comment

LGTM, just some nits. Thanks for all the refactoring so far!

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp: resolved (outdated)
Comment on lines 4912 to 4913:

```cpp
if (Attribute attr =
        llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
```

Max191 (Contributor):
nit: Can you add a local variable (similar to what you did above with shape) for the second iterator's value (e.g., something like tile)? I think it makes it more clear what the iterator is when reading the code.

banach-space (Contributor, Author):
Good point! And to keep things consistent, let me update FoldTensorCastPackOp as well.

Comment on lines +4884 to +4885:

```cpp
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;
```

Max191 (Contributor):
nit: Most of the logic in this function is the same as for tensor.pack, but with the source type instead of the dest type. Could you refactor the logic a bit to try to share code from a single function (mainly for finding the new mixed tile sizes)?


banach-space (Contributor, Author):
Great point, sending update shortly. Thanks for the suggestion!
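
For illustration, here is a sketch of what such a shared helper might look like. The name `getNewMixedTileSizes` and its exact signature are assumptions rather than the final code; the loop body is lifted from this patch, with the `tile` local variable from the earlier nit applied. `tensor.pack` would pass its (new) result type and `tensor.unpack` its (new) source type:

```cpp
// Sketch: recompute the mixed tile sizes against a more static packed type.
// Dims that stay dynamic keep their original OpFoldResult; dims that became
// static get their dynamic tile-size SSA value replaced by a constant index
// attribute.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    OpFoldResult tile = std::get<1>(it);
    // The dim is still dynamic in the new type: preserve the mixed size.
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(tile);
      continue;
    }
    // The dim is static: keep an existing constant attribute as-is;
    // otherwise replace the dynamic tile size with the static constant.
    if (llvm::dyn_cast_if_present<Attribute>(tile)) {
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }
  return newMixedTileSizes;
}
```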

```mlir
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

Max191 (Contributor):
nit: Maybe use the same style as the test below for the attribute? (i.e., {test_attr} instead of {some_attr})

banach-space (Contributor, Author):
I've actually followed the test above:

```mlir
// CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
// CHECK-SAME:        some_attr
// CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
  %dest: tensor<1x1x8x1xi32>,
  %src: tensor<7x?xi32>,
  %pad: i32) -> tensor<1x1x8x1xi32> {

    %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
    %c8 = arith.constant 8 : index
    %pack = tensor.pack %src padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
    return %res : tensor<1x1x8x1xi32>
}
```

😂 Let me unify this.
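
Per the updated PR description above, which settles on `test_attr`, the unified version of the new test would presumably read, e.g.:

```mlir
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```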

banach-space merged commit 9f6a1dd into llvm:main on Jan 3, 2025
8 checks passed
banach-space deleted the andrzej/add_fold_cast_into_unpack branch on January 3, 2025 at 10:16
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jan 3, 2025
Adds an end-to-end test for `tensor.unpack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required
a few fixes in handling `tensor.unpack` (and similar fixes for
`tensor.pack` before that):

* llvm#119379, llvm#121393, llvm#121400.

The end goal for this test is to incrementally increase its complexity
and to work towards scalable tile sizes.

Note, this PR complements llvm#115698 in which similar test for
`tensor.pack` was added.