[mlir][tensor] Introduce FoldTensorCastUnPackOp #121393

Merged
Changes from 1 commit
89 changes: 85 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
-      int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-      assert(tileSize == shape && "tile size and dim size don't match!");
-      (void)tileSize;
+      assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+             "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}

// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};

/// Folds a tensor.cast op into a consuming tensor::UnPackOp if the
/// `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
/// ```
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
using OpRewritePattern<UnPackOp>::OpRewritePattern;
Comment on lines +4902 to +4903

Reviewer (Contributor): nit: Most of the logic in this function is the same as for tensor.pack, but with the source type instead of the dest type. Could you refactor the logic a bit to try to share code from a single function (mainly for finding the new mixed tile sizes)?

Author: Great point, sending update shortly. Thanks for the suggestion!
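
To make the suggestion concrete, here is a minimal sketch of the kind of shared helper the reviewer is describing. The name getNewMixedTileSizes and its signature are assumptions for illustration, not the refactor that actually landed; the loop body is lifted from the two patterns in this diff.

```cpp
// Illustrative-only sketch: recompute the mixed tile sizes against a
// (possibly more static) packed type. Both FoldTensorCastPackOp and
// FoldTensorCastUnPackOp could call this instead of duplicating the loop.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     ArrayRef<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    // A dynamic dim in the new type gives no extra information; keep the
    // original mixed tile size.
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }
    // The dim is static. Keep an existing constant attribute; otherwise
    // fold the dynamic SSA tile size into a constant index attribute.
    if (llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
      newMixedTileSizes.push_back(std::get<1>(it));
    } else {
      assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }
  return newMixedTileSizes;
}
```

Under this sketch, the pack pattern would pass the new packed dest type and the unpack pattern the new packed source type (newOperands[0].getType()).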


LogicalResult matchAndRewrite(UnPackOp op,
PatternRewriter &rewriter) const override {
if (!foldTensorCastPrecondition(op))
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
Value sourceTensor = newOperands[0];

// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes;
for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
.getShape()
.take_back(op.getMixedTiles().size()),
op.getMixedTiles())) {
int64_t shape = std::get<0>(it);
// If the current source shape is dynamic, just preserve this mixed
// size.
if (shape == ShapedType::kDynamic) {
newMixedTileSizes.push_back(std::get<1>(it));
continue;
}

// If the current source is static, update the dynamic mixed-size
// (provided the original value is dynamic).
if (Attribute attr =
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
Reviewer (Contributor): nit: Can you add a local variable (similar to what you did above with shape) for the second iterator's value (e.g., something like tile)? I think it makes it more clear what the iterator is when reading the code.

Author: Good point! And to keep things consistent, let me update FoldTensorCastPackOp as well.
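
For illustration, here is the loop body above with both tuple elements bound to named locals, as the reviewer suggests. This is a hypothetical rewrite, not the committed follow-up:

```cpp
// Name both tuple elements up front so later uses don't need to decode
// std::get<0>/std::get<1>.
int64_t shape = std::get<0>(it);
OpFoldResult tile = std::get<1>(it);
if (llvm::dyn_cast_if_present<Attribute>(tile)) {
  // Already a constant.
  newMixedTileSizes.push_back(tile);
} else {
  assert(getConstantIntValue(tile).value() == shape &&
         "tile size and dim size don't match!");
  newMixedTileSizes.push_back(
      rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
}
```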

// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
"tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}

// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
UnPackOp newOp = rewriter.create<UnPackOp>(
op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getOuterDimsPerm());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

// Replace op.
Value oldResult = op.getResult();
Value newResult = newOp.getResult();
Value replacement = (newResult.getType() != oldResult.getType())
? rewriter.create<tensor::CastOp>(
op->getLoc(), oldResult.getType(), newResult)
: newResult;

rewriter.replaceOp(op, {replacement});

return success();
}
};

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has a source that is more static than the consuming op.
///
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {

    // Reject tensor::PackOp and tensor::UnPackOp - those have dedicated
    // patterns instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5003,7 @@
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}
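
For context, a sketch of how these registered canonicalization patterns can be exercised programmatically, outside of mlir-opt's -canonicalize pass. The driver function below is an assumption for illustration and is not part of this PR:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the tensor dialect's canonicalization patterns (now including
// FoldTensorCastPackOp and FoldTensorCastUnPackOp) and apply them greedily
// to a module. Illustrative only; assumes the dialect is already loaded.
static LogicalResult canonicalizeTensorOps(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  ctx->getLoadedDialect<tensor::TensorDialect>()
      ->getCanonicalizationPatterns(patterns);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```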

21 changes: 21 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}

// -----

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(

// -----

// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
// CHECK: return %[[RES]] : tensor<7x?xi32>
func.func @fold_cast_unpack_dynamic_tile_size(
%src: tensor<1x1x8x1xi32>,
%res: tensor<7x?xi32>) -> tensor<7x?xi32> {

%cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%c8 = arith.constant 8 : index
%unpack = tensor.unpack %cast
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
Reviewer (Contributor): nit: Maybe use the same style as the test below for the attribute? (i.e., {test_attr} instead of {some_attr})

Author: I've actually followed the test above:

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
// CHECK-SAME: some_attr
// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
%dest: tensor<1x1x8x1xi32>,
%src: tensor<7x?xi32>,
%pad: i32) -> tensor<1x1x8x1xi32> {
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%c8 = arith.constant 8 : index
%pack = tensor.pack %src padding_value(%pad : i32)
inner_dims_pos = [0, 1]
inner_tiles = [%c8, 1]
into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}

😂 Let me unify this.

return %unpack : tensor<7x?xi32>
}

// -----

// CHECK-LABEL: func.func @pack_dont_drop_attributes(
// CHECK: tensor.pack {{.*}} {test_attr}
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor

// -----

-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>