diff --git a/BUILD b/BUILD
index 33b0850..67765eb 100644
--- a/BUILD
+++ b/BUILD
@@ -26,6 +26,7 @@ td_library(
     ],
     includes = ["include"],
     deps = [
+        "@llvm-project//mlir:BuiltinDialectTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
     ],
@@ -58,11 +59,17 @@ gentbl_cc_library(
             "include/mlir-tcp/Dialect/IR/TcpDialect.cpp.inc",
         ),
         (
-            ["-gen-attrdef-decls"],
+            [
+                "-gen-attrdef-decls",
+                "-attrdefs-dialect=tcp",
+            ],
             "include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc",
         ),
         (
-            ["-gen-attrdef-defs"],
+            [
+                "-gen-attrdef-defs",
+                "-attrdefs-dialect=tcp",
+            ],
             "include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc",
         ),
         (
@@ -142,6 +149,7 @@ gentbl_cc_library(
 cc_library(
     name = "TcpDialectPasses",
     srcs = [
+        "lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp",
         "lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
         "lib/Dialect/Transforms/FusionPatterns.cpp",
         "lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
@@ -151,6 +159,7 @@ cc_library(
         "lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp",
     ],
     hdrs = [
+        "include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h",
         "include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
         "include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
         "include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 3408d64..45fa0e9 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -11,6 +11,7 @@
 #define TCP_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "mlir-tcp/Dialect/IR/TcpBase.td"
 include "mlir-tcp/Dialect/IR/TcpEnums.td"
 
@@ -640,4 +641,69 @@ def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, Sa
   let assemblyFormat = "$in `starts` `(` $starts `)` `sizes` `(` $sizes `)` `strides` `(` $strides `)` attr-dict `:` type($in) `->` type($out)";
 }
 
+//===----------------------------------------------------------------------===//
+// Symbolic shape modeling ops for the TorchDynamo frontend.
+//===----------------------------------------------------------------------===//
+
+def Tcp_SymbolicIntOp : Tcp_Op<"symbolic_int", [Pure]> {
+
+  let summary = "Symbolic int representing a dynamic dimension";
+
+  let description = [{
+    The `tcp.symbolic_int` operation captures a dynamic dimension on the
+    global function arguments. It associates shape symbols (e.g. "s0",
+    "s1") with global SSA values (e.g. `%0`, `%1`) that are then
+    referenced to bind shapes on op results.
+
+    Additionally, the operation carries `min_val` and `max_val` attributes
+    denoting the range constraints for the dynamic dimension. This may be
+    useful for modeling runtime shape guards, or for compile-time
+    optimizations based on the shape bounds (min, opt, max) of op / region
+    results.
+
+    Example:
+    ```
+    %0 = tcp.symbolic_int "s0" {min_val = 5, max_val = 10} : i64
+    %1 = tcp.symbolic_int "s1" {min_val = 2, max_val = 20} : i64
+    ```
+  }];
+
+  let arguments = (ins
+    StrAttr:$symbol_name,
+    I64Attr:$min_val,
+    I64Attr:$max_val
+  );
+  let results = (outs
+    AnySignlessInteger:$result
+  );
+  let assemblyFormat = [{
+    $symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result)
+  }];
+}
+
+def Tcp_BindSymbolicShapeOp : Tcp_Op<"bind_symbolic_shape", []> {
+  let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols";
+  let description = [{
+    The `tcp.bind_symbolic_shape` operation binds the shape expressions
+    used to compute the dynamic dimensions of a tensor. It takes a
+    variadic list of SSA symbols that map 1:1 to the local symbols
+    declared in the affine map. The affine map contains one affine shape
+    expression per dim, whose terminals are the declared symbols.
+
+    Example:
+    ```
+    tcp.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+    tcp.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : tensor<?x?x3xf32>
+    ```
+  }];
+  let arguments = (ins
+    Tcp_Tensor:$operand,
+    Variadic<AnySignlessInteger>:$shape_symbols,
+    Builtin_AffineMapAttr:$shape_expressions
+  );
+  let results = (outs);
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 #endif // TCP_OPS
diff --git a/include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h b/include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h
new file mode 100644
index 0000000..d7059e9
--- /dev/null
+++ b/include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h
@@ -0,0 +1,22 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir::tcp {
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDropSymbolicShapeOpsPass();
+
+} // namespace mlir::tcp
diff --git a/include/mlir-tcp/Dialect/Transforms/Passes.td b/include/mlir-tcp/Dialect/Transforms/Passes.td
index c0bc96f..58dbdbb 100644
--- a/include/mlir-tcp/Dialect/Transforms/Passes.td
+++ b/include/mlir-tcp/Dialect/Transforms/Passes.td
@@ -38,4 +38,11 @@ def DecomposeTensorOps : Pass<"decompose-tensor-ops", "func::FuncOp"> {
   let constructor = "mlir::tcp::createDecomposeTensorOpsPass()";
 }
 
+// \brief This pass removes any unused symbolic shape ops.
+// We discard remaining bind shape ops during backend lowering.
+def DropSymbolicShapeOps : Pass<"drop-symbolic-shape-ops", "func::FuncOp"> {
+  let summary = "Removes all remaining symbolic shape ops.";
+  let constructor = "mlir::tcp::createDropSymbolicShapeOpsPass()";
+}
+
 #endif // TCP_PASSES
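Editor's note (not part of the patch): the `shape_expressions` affine map is meant to be consumed through the ODS-generated accessors; the diff's own `print()` method below confirms that `getShapeExpressions()` returns an `AffineMapAttr`. A minimal sketch of a downstream walker under that assumption — the helper name `dumpBoundShapes` is hypothetical:

```
#include "mlir-tcp/Dialect/IR/TcpOps.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: print each tensor's bound shape expressions.
void dumpBoundShapes(mlir::ModuleOp module) {
  module.walk([](mlir::tcp::BindSymbolicShapeOp bindOp) {
    // One affine result per dimension of `operand`, written in terms of the
    // map's symbols, which bind 1:1 to the `shape_symbols` operands.
    mlir::AffineMap map = bindOp.getShapeExpressions().getValue();
    for (mlir::AffineExpr expr : map.getResults())
      llvm::outs() << expr << "\n";
  });
}
```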
diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp
index 6ab9605..a58e307 100644
--- a/lib/Conversion/TorchToTcp/Misc.cpp
+++ b/lib/Conversion/TorchToTcp/Misc.cpp
@@ -275,12 +275,51 @@ class ConvertAtenZerosOnesLikeOp : public OpConversionPattern<AtenOpT> {
   }
 };
 
+class ConvertSymbolicIntOp : public OpConversionPattern<Torch::SymbolicIntOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Torch::SymbolicIntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = getTypeConverter()->convertType(op.getType());
+
+    rewriter.replaceOpWithNewOp<tcp::SymbolicIntOp>(
+        op, resultType, adaptor.getSymbolNameAttr(), adaptor.getMinValAttr(),
+        adaptor.getMaxValAttr());
+    return success();
+  }
+};
+
+class ConvertBindSymbolicShapeOp
+    : public OpConversionPattern<Torch::BindSymbolicShapeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Torch::BindSymbolicShapeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<tcp::BindSymbolicShapeOp>(
+        op, adaptor.getOperand(), adaptor.getShapeSymbols(),
+        adaptor.getShapeExpressionsAttr());
+    return success();
+  }
+};
+
 } // namespace
 
 void torch_to_tcp::populateMiscPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) {
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertSymbolicIntOp,
+                                                   Torch::SymbolicIntOp>(
+      typeConverter, patterns, target, convertTorchOpsSet);
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertBindSymbolicShapeOp,
+                                                   Torch::BindSymbolicShapeOp>(
+      typeConverter, patterns, target, convertTorchOpsSet);
+
 #define INSERT_ATEN_MISC_OP_PATTERN(AtenOp)                                    \
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<Convert##AtenOp, AtenOp>(   \
       typeConverter, patterns, target, convertTorchOpsSet)
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 63250e8..3a8e5af 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -170,4 +170,67 @@ LogicalResult CastOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BindSymbolicShapeOp
+//===----------------------------------------------------------------------===//
+
+//
+// tcp.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] ->
+// (s0, s1 * 2 + s2, 3)> : tensor<?x?x3xf32>
+//
+
+ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::UnresolvedOperand operand;
+  SmallVector<OpAsmParser::UnresolvedOperand> shapeSymbols;
+  AffineMapAttr shapeExpressions;
+  Type operandType;
+
+  if (parser.parseOperand(operand) || parser.parseComma() ||
+      parser.parseLSquare() || parser.parseOperandList(shapeSymbols) ||
+      parser.parseRSquare() || parser.parseComma() ||
+      parser.parseAttribute(shapeExpressions, "shape_expressions",
+                            result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(operandType)) {
+    return failure();
+  }
+
+  if (parser.resolveOperand(operand, operandType, result.operands) ||
+      parser.resolveOperands(shapeSymbols,
+                             parser.getBuilder().getType<IntegerType>(64),
+                             result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+// Use a custom printer here to keep the AffineMap from getting hoisted when
+// printed. This way the AffineMap is printed inline with the op.
+void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
+  p << " " << getOperand() << ", [";
+  llvm::interleaveComma(getShapeSymbols(), p);
+  p << "], "
+    << "affine_map<" << getShapeExpressions().getValue() << ">";
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/{"shape_expressions"});
+  p << " : " << getOperand().getType();
+}
+
+LogicalResult BindSymbolicShapeOp::verify() {
+  if (getShapeSymbols().empty())
+    return emitOpError() << "requires non-empty shapeSymbols";
+
+  for (auto symbol : getShapeSymbols()) {
+    Operation *definingOp = symbol.getDefiningOp();
+    if (!isa<SymbolicIntOp>(definingOp)) {
+      return emitOpError()
+             << "shape symbol must be produced by a SymbolicIntOp";
+    }
+  }
+
+  return success();
+}
+
 } // namespace mlir::tcp
diff --git a/lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp b/lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp
new file mode 100644
index 0000000..c3a1f06
--- /dev/null
+++ b/lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp
@@ -0,0 +1,59 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
+
+#include "mlir-tcp/Dialect/IR/TcpDialect.h"
+#include "mlir-tcp/Dialect/IR/TcpOps.h"
+#include "mlir-tcp/Dialect/Transforms/Passes.h"
+
+#include "./PassDetail.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace mlir::tcp {
+
+namespace {
+
+class RemoveBindSymbolicShapeOps
+    : public OpRewritePattern<tcp::BindSymbolicShapeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+class DropSymbolicShapeOpsPass
+    : public DropSymbolicShapeOpsBase<DropSymbolicShapeOpsPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
+
+    patterns.add<RemoveBindSymbolicShapeOps>(context);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createDropSymbolicShapeOpsPass() {
+  return std::make_unique<DropSymbolicShapeOpsPass>();
+}
+
+} // namespace mlir::tcp
diff --git a/lib/Dialect/Transforms/Passes.cpp b/lib/Dialect/Transforms/Passes.cpp
index 18bb5bb..dd1dab6 100644
--- a/lib/Dialect/Transforms/Passes.cpp
+++ b/lib/Dialect/Transforms/Passes.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-tcp/Dialect/Transforms/Passes.h"
+#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h"
 #include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h"
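Editor's note (not part of the patch): `RemoveBindSymbolicShapeOps` only erases `tcp.bind_symbolic_shape`. The now-unused `tcp.symbolic_int` producers disappear as well because the op is declared `Pure` and `applyPatternsAndFoldGreedily()` removes trivially dead ops. If that DCE behavior were not available, an explicit pattern would look roughly like the sketch below (hypothetical, assuming the same includes as the pass above):

```
// Sketch only: erase tcp.symbolic_int ops once nothing references them.
// Unnecessary in the actual pass, since SymbolicIntOp is Pure and the
// greedy driver already folds away trivially dead ops.
class RemoveUnusedSymbolicIntOps : public OpRewritePattern<tcp::SymbolicIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tcp::SymbolicIntOp op,
                                PatternRewriter &rewriter) const override {
    // Keep the op while a tcp.bind_symbolic_shape (or anything else) uses it.
    if (!op->use_empty())
      return rewriter.notifyMatchFailure(op, "still has uses");
    rewriter.eraseOp(op);
    return success();
  }
};
```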
+#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h" #include "mlir-tcp/Dialect/Transforms/TransformTensorOps.h" #include "mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h" @@ -65,6 +66,9 @@ static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) { } static void createTcpToLlvmPipeline(OpPassManager &pm) { + // Drop TCP symbolic shape ops for dynamic dims + pm.addNestedPass(tcp::createDropSymbolicShapeOpsPass()); + // TCP transformations. pm.addNestedPass(tcp::createDecomposeTensorOpsPass()); diff --git a/test/Conversion/TorchToTcp/misc.mlir b/test/Conversion/TorchToTcp/misc.mlir index cdc47d5..d6c1972 100644 --- a/test/Conversion/TorchToTcp/misc.mlir +++ b/test/Conversion/TorchToTcp/misc.mlir @@ -418,3 +418,40 @@ func.func @torch.aten.broadcast_to_dynamic_dim(%arg0: !torch.vtensor<[1,2],f32>, %2 = torch.aten.broadcast_to %arg0, %1 : !torch.vtensor<[1,2],f32>, !torch.list -> !torch.vtensor<[?,2],f32> return %2 : !torch.vtensor<[?,2],f32> } + +// ----- + +// CHECK-LABEL: @symbolic_shape_ops( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG2:.*]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +// CHECK: %[[S0:.*]] = tcp.symbolic_int "s0" {min_val = 5, max_val = 10} : i64 +// CHECK: %[[S1:.*]] = tcp.symbolic_int "s1" {min_val = 0, max_val = 100} : i64 +// CHECK: %[[S3:.*]] = tcp.symbolic_int "s3" {min_val = 0, max_val = 50} : i64 +// CHECK: %[[S5:.*]] = tcp.symbolic_int "s5" {min_val = 0, max_val = {{[0-9]+}}} : i64 +// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor +// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor +// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S5]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor +// CHECK: %[[TANH:.*]] = tcp.tanh %{{.*}} : tensor -> tensor +// CHECK: tcp.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor +// CHECK: %[[SIGM:.*]] = tcp.sigmoid %{{.*}} : tensor -> tensor +// CHECK: tcp.bind_symbolic_shape %[[SIGM]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor +// CHECK: %[[CAT:.*]] = tensor.concat dim(1) %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: tcp.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S3]], %[[S5]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : tensor +// CHECK: return %{{.*}} : !torch.vtensor<[?,?,3],f32> +func.func @symbolic_shape_ops(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>, %arg2: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { + %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int + %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int + %3 = torch.symbolic_int "s5" {min_val = 0, max_val = 9223372036854775806} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %arg1, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %arg2, [%0, %3], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + %4 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 
diff --git a/test/Dialect/canonicalize.mlir b/test/Dialect/canonicalize.mlir
index 9d72699..d95d625 100644
--- a/test/Dialect/canonicalize.mlir
+++ b/test/Dialect/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: tcp-opt %s -canonicalize | FileCheck %s
+// RUN: tcp-opt %s -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @test_constant_folding() -> tensor<f32>
 // CHECK: %[[CONST0:.*]] = tcp.const {value = dense<2.500000e+00> : tensor<f32>} : tensor<f32>
@@ -10,3 +10,20 @@ func.func @test_constant_folding() -> tensor<f32> {
   %2 = tcp.mul %0, %1 : tensor<f32>, tensor<f32> -> tensor<f32>
   return %2 : tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_tcp_symbolic_int$canonicalize(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[S0:.*]] = tcp.symbolic_int "s0" {min_val = 3, max_val = 6} : i64
+// CHECK-NOT: %[[S1:.*]] = tcp.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : i64
+// CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : tensor<?xf32>
+// CHECK: tcp.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : tensor<?xf32>
+// CHECK: return %[[ARG0]] : tensor<?xf32>
+func.func @test_tcp_symbolic_int$canonicalize(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcp.symbolic_int "s0" {min_val = 3, max_val = 6} : i64
+  %1 = tcp.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : i64
+  tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : tensor<?xf32>
+  tcp.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : tensor<?xf32>
+  return %arg0 : tensor<?xf32>
+}
diff --git a/test/Dialect/drop_symbolic_shape_ops.mlir b/test/Dialect/drop_symbolic_shape_ops.mlir
new file mode 100644
index 0000000..2c54b66
--- /dev/null
+++ b/test/Dialect/drop_symbolic_shape_ops.mlir
@@ -0,0 +1,16 @@
+// RUN: tcp-opt %s -drop-symbolic-shape-ops | FileCheck %s
+
+// CHECK-LABEL: func.func @test_drop_symbolic_shape_ops(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: %[[S0:.*]] = tcp.symbolic_int "s0" {min_val = 3, max_val = 6} : i64
+// CHECK-NOT: %[[S1:.*]] = tcp.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : i64
+// CHECK-NOT: tcp.bind_symbolic_shape %[[ARG0]], [%{{.*}}], affine_map<()[s0] -> (s0)> : tensor<?xf32>
+// CHECK-NOT: tcp.bind_symbolic_shape %[[ARG1]], [%{{.*}}], affine_map<()[s0] -> (s0 + 1)> : tensor<?xf32>
+// CHECK: return %[[ARG0]] : tensor<?xf32>
+func.func @test_drop_symbolic_shape_ops(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcp.symbolic_int "s0" {min_val = 3, max_val = 6} : i64
+  %1 = tcp.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : i64
+  tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : tensor<?xf32>
+  tcp.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : tensor<?xf32>
+  return %arg0 : tensor<?xf32>
+}
diff --git a/test/python_lit/fx_import/custom_op_test.py b/test/python_lit/fx_import/custom_op_test.py
new file mode 100644
index 0000000..d4105c2
--- /dev/null
+++ b/test/python_lit/fx_import/custom_op_test.py
@@ -0,0 +1,86 @@
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import torch.nn as nn
+from torch.export import Dim
+from torch.library import Library, impl, impl_abstract
+
+from torch_mlir import fx
+
+
+def run(f):
+    print(f"{f.__name__}")
+    print("-" * len(f.__name__))
+    f()
+    print()
+
+
+@run
+# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
+# CHECK: func.func @main(
+# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
+# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
+# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
+# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
+def test_tanh_sigmoid_cat_custom_op():
+
+    m = Library("my_custom_library", "DEF")
+    m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")
+
+    @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
+    def custom_op(x, y, z):
+        a = torch.tanh(x)
+        b = torch.sigmoid(y)
+        return torch.cat((a, a, b, z), dim=1)
+
+    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
+    def custom_op_meta(x, y, z):
+        result = custom_op(x, y, z)
+        return torch.empty_like(result)
+
+    class TanhSigmoidCatCustomOp(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y, z):
+            return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)
+
+    # Sample inputs
+    x = torch.randn(5, 2, 3)
+    y = torch.randn(5, 6, 3)
+    z = torch.randn(5, 4, 3)
+
+    # Dynamic dim constraints
+    dim_n = Dim("n", min=5, max=10)
+    dim_x1 = Dim("x1", max=100)
+    dim_y1 = Dim("y1", max=50)
+    dim_z1 = Dim("z1")
+    dynamic_shapes = {
+        "x": {0: dim_n, 1: dim_x1},
+        "y": {0: dim_n, 1: dim_y1},
+        "z": {0: dim_n, 1: dim_z1},
+    }
+
+    m = fx.export_and_import(
+        TanhSigmoidCatCustomOp(),
+        x,
+        y,
+        z,
+        dynamic_shapes=dynamic_shapes,
+        import_symbolic_shape_expressions=True,
+    )
+    print(m)
diff --git a/test/python_lit/fx_import/symbolic_shape_expr_test.py b/test/python_lit/fx_import/symbolic_shape_expr_test.py
new file mode 100644
index 0000000..fd207a8
--- /dev/null
+++ b/test/python_lit/fx_import/symbolic_shape_expr_test.py
@@ -0,0 +1,78 @@
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import torch.export
+import torch.nn as nn
+from torch.export import Dim
+
+from torch_mlir import fx
+
+
+def run(f):
+    print(f"{f.__name__}")
+    print("-" * len(f.__name__))
+    f()
+    print()
+
+
+@run
+# CHECK-LABEL: test_tanh_sigmoid_cat
+# CHECK: func.func @main(
+# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
+# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
+# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
+# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
+# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32>
+def test_tanh_sigmoid_cat():
+    class TanhSigmoidCat(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y, z):
+            a = torch.tanh(x)
+            b = torch.sigmoid(y)
+            return torch.cat((a, a, b, z), dim=1)
+
+    # Sample inputs
+    x = torch.randn(5, 2, 3)
+    y = torch.randn(5, 6, 3)
+    z = torch.randn(5, 4, 3)
+
+    # Dynamic dim constraints
+    dim_n = Dim("n", min=5, max=10)
+    dim_x1 = Dim("x1", max=100)
+    dim_y1 = Dim("y1", max=50)
+    dim_z1 = Dim("z1")
+    dynamic_shapes = {
+        "x": {0: dim_n, 1: dim_x1},
+        "y": {0: dim_n, 1: dim_y1},
+        "z": {0: dim_n, 1: dim_z1},
+    }
+
+    m = fx.export_and_import(
+        TanhSigmoidCat(),
+        x,
+        y,
+        z,
+        dynamic_shapes=dynamic_shapes,
+        import_symbolic_shape_expressions=True,
+    )
+    print(m)
diff --git a/tools/aot/torch_exporter_harness.py b/tools/aot/torch_exporter_harness.py
index 98c25ef..c30060e 100644
--- a/tools/aot/torch_exporter_harness.py
+++ b/tools/aot/torch_exporter_harness.py
@@ -47,6 +47,7 @@ def main():
         loader_result.model,
         *loader_result.inputs,  # unpack list of input tensors
         dynamic_shapes=loader_result.dynamic_shapes,
+        import_symbolic_shape_expressions=True,
         func_name=loader_result.func_name,
     )
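Editor's note (not part of the patch): besides `tcp-opt -drop-symbolic-shape-ops` as exercised in the lit test above, a downstream consumer could schedule the new pass outside the bundled `createTcpToLlvmPipeline()`. A minimal sketch, where the wrapper function itself is hypothetical:

```
#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical standalone driver: nest the pass on func.func, matching the
// anchor op declared for DropSymbolicShapeOps in Passes.td.
mlir::LogicalResult dropSymbolicShapes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tcp::createDropSymbolicShapeOpsPass());
  return pm.run(module);
}
```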