diff --git a/.gitmodules b/.gitmodules index 55e20273a2..8c5cd97073 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,7 @@ [submodule "third_party/mlir-hlo"] path = third_party/mlir-hlo url = https://github.com/tensorflow/mlir-hlo.git +[submodule "third_party/torch-mlir"] + path = third_party/torch-mlir + url = https://github.com/nod-ai/torch-mlir.git + branch = external-builder diff --git a/CMakeLists.txt b/CMakeLists.txt index 23789a4c29..075740ef25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ project(onnx-mlir) option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON) option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF) option(ONNX_MLIR_ENABLE_MHLO "Enable MHLO support." ON) +option(ONNX_MLIR_ENABLE_TORCH "Enable Torch support." ON) option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF) option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON) @@ -152,6 +153,13 @@ if (ONNX_MLIR_ENABLE_MHLO) add_subdirectory(third_party/mlir-hlo EXCLUDE_FROM_ALL) endif() +if (ONNX_MLIR_ENABLE_TORCH) + set(TORCH_MLIR_ENABLE_MHLO OFF) + set(TORCH_MLIR_ENABLE_LTC OFF) + set(TORCH_MLIR_OUT_OF_TREE_BUILD ON) + add_subdirectory(third_party/torch-mlir EXCLUDE_FROM_ALL) +endif() + if (NOT TARGET benchmark) set(BENCHMARK_USE_BUNDLED_GTEST OFF) set(BENCHMARK_ENABLE_GTEST_TESTS OFF) @@ -185,6 +193,10 @@ if (ONNX_MLIR_ENABLE_MHLO) add_compile_definitions(ONNX_MLIR_ENABLE_MHLO) endif() +if (ONNX_MLIR_ENABLE_TORCH) + add_compile_definitions(ONNX_MLIR_ENABLE_TORCH) +endif() + add_subdirectory(utils) add_subdirectory(include) add_subdirectory(src) diff --git a/src/Conversion/CMakeLists.txt b/src/Conversion/CMakeLists.txt index 0f95fdfa63..af89666495 100644 --- a/src/Conversion/CMakeLists.txt +++ b/src/Conversion/CMakeLists.txt @@ -9,3 +9,7 @@ add_subdirectory(ONNXToTOSA) if (ONNX_MLIR_ENABLE_MHLO) add_subdirectory(ONNXToMhlo) endif() + +if 
(ONNX_MLIR_ENABLE_TORCH) + add_subdirectory(ONNXToTorch) +endif() diff --git a/src/Conversion/ONNXToTorch/CMakeLists.txt b/src/Conversion/ONNXToTorch/CMakeLists.txt new file mode 100644 index 0000000000..e11e841023 --- /dev/null +++ b/src/Conversion/ONNXToTorch/CMakeLists.txt @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Please keep in alphabetical order. + +install(TARGETS + TorchMLIRTorchDialect + TorchMLIRTorchUtils + ) + +add_onnx_mlir_library(OMONNXToTorch + ConvertONNXToTorch.cpp + ConvertONNXToTorchPipeline.cpp + EraseONNXEntryPoint.cpp + ONNXToTorchCommon.cpp + TypeConversion/TorchTypeConversion.cpp + TypeConversion/TorchTypeConversionPasses.cpp + + Math/Elementwise.cpp + + Tensor/Constant.cpp + + LINK_LIBS PUBLIC + TorchMLIRTorchDialect + TorchMLIRTorchUtils + OMONNXOps + ) diff --git a/src/Conversion/ONNXToTorch/ConvertONNXToTorch.cpp b/src/Conversion/ONNXToTorch/ConvertONNXToTorch.cpp new file mode 100644 index 0000000000..c71bfd5454 --- /dev/null +++ b/src/Conversion/ONNXToTorch/ConvertONNXToTorch.cpp @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering +//-------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// =============================================================================== +// +// This file implements the lowering of frontend operations to Torch backend IR. 
+// +//===------------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +void populateONNXToTorchConversionPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + // Math + populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx); + populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx); +} + +//===----------------------------------------------------------------------===// +// Frontend to Torch Dialect lowering pass +//===----------------------------------------------------------------------===// + +struct FrontendToTorchLoweringPass + : public PassWrapper> { + + StringRef getArgument() const override { return "convert-onnx-to-torch"; } + + StringRef getDescription() const override { + return "Lower frontend ops to Torch dialect."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. + FrontendToTorchLoweringPass() = default; + FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass) + : PassWrapper>() {} + + void runOnOperation() final; +}; + +void FrontendToTorchLoweringPass::runOnOperation() { + ModuleOp module = getOperation(); + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + onnx_mlir::setupTorchTypeConversion(target, typeConverter); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations.
+ RewritePatternSet patterns(&getContext()); + + // Define patterns. + populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr createLowerToTorchPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/ConvertONNXToTorchPipeline.cpp b/src/Conversion/ONNXToTorch/ConvertONNXToTorchPipeline.cpp new file mode 100644 index 0000000000..b12c73a4aa --- /dev/null +++ b/src/Conversion/ONNXToTorch/ConvertONNXToTorchPipeline.cpp @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ConvertONNXToTorchPipeline.cpp - ONNX dialects to Torch lowering +// pipeline -------===// +// +// Copyright 2019-2022 The IBM Research Authors. 
+// +// ================================================================================================ +// +// This file registers the pipeline for converting ONNX to Torch Backend IR +// +//===-----------------------------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +void registerONNXFrontendToTorchBackendPasses() { + PassPipelineRegistration<>("convert-onnx-to-torch-pipeline", + "Pipeline converting ONNX to Torch dialect.", + onnx_mlir::createONNXFrontendToTorchBackendPasses); +} + +void createONNXFrontendToTorchBackendPasses(OpPassManager &pm) { + pm.addPass(createLowerToTorchPass()); + pm.addPass(createFuncTorchTypeConversionPass()); + pm.addPass(createFinalizingTorchTypeConversionPass()); + pm.addPass(createEraseONNXEntryPointPass()); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/EraseONNXEntryPoint.cpp b/src/Conversion/ONNXToTorch/EraseONNXEntryPoint.cpp new file mode 100644 index 0000000000..765ccf7df1 --- /dev/null +++ b/src/Conversion/ONNXToTorch/EraseONNXEntryPoint.cpp @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- EraseONNXEntryPoint.cpp - ONNX dialects to Torch lowering +//---------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ==================================================================================== +// +// This file implements a pass for removing the ONNXEntryPointOp for +// compatibility when converting to Torch backend IR.
+// +//===------------------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +struct EraseONNXEntryPointPass + : public PassWrapper> { + + StringRef getArgument() const override { return "erase-onnx-entry-point"; } + + StringRef getDescription() const override { + return "Erase ONNXEntryPointOp."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. + EraseONNXEntryPointPass() = default; + EraseONNXEntryPointPass(const EraseONNXEntryPointPass &pass) + : PassWrapper>() {} + + void runOnOperation() override { + auto walkResult = getOperation().walk([](ONNXEntryPointOp op) { + op.erase(); + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) { + return signalPassFailure(); + } + } +}; +} // namespace onnx_mlir + +// std::unique_ptr> +std::unique_ptr onnx_mlir::createEraseONNXEntryPointPass() { + return std::make_unique(); +} diff --git a/src/Conversion/ONNXToTorch/Math/Elementwise.cpp b/src/Conversion/ONNXToTorch/Math/Elementwise.cpp new file mode 100644 index 0000000000..38106622b0 --- /dev/null +++ b/src/Conversion/ONNXToTorch/Math/Elementwise.cpp @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Elementwise.cpp - Elementwise Ops -------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers ONNX element-wise operators to Torch dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; +using namespace mlir::torch; + +namespace onnx_mlir { + +namespace { + +// AtenAddOp requires an additional alpha parameter and thus requires a unique +// lowering +class ConvertONNXAddOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + auto newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, newResultType, adaptor.A(), adaptor.B(), one); + return success(); + } +}; +} // namespace + +void populateLoweringONNXElementwiseOpToTorchPattern( + TypeConverter &typeConverter, RewritePatternSet &patterns, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/ONNXToTorchCommon.cpp b/src/Conversion/ONNXToTorch/ONNXToTorchCommon.cpp new file mode 100644 index 0000000000..91e9f533ab --- /dev/null +++ b/src/Conversion/ONNXToTorch/ONNXToTorchCommon.cpp @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- ONNXToTorchCommon.cpp - ONNX dialects to Torch lowering +//---------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// =============================================================================== +// +// This file contains common code shared by the functions performing the +// lowering to the Torch backend dialect. 
+// +//===------------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +namespace onnx_mlir {} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp b/src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp new file mode 100644 index 0000000000..66a969835c --- /dev/null +++ b/src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToTorchCommon.hpp - ONNX dialects to Torch lowering +//--------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// =============================================================================== +// +// This file contains common code shared by the functions performing the +// lowering to the Torch dialect. +// +//===------------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" + +#include "src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp" +#include "src/Dialect/Mlir/IndexExpr.hpp" +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Transform/ONNX/ConstPropHelper.hpp" + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// Functions for populating the conversion patterns for the lowerings. 
+//===----------------------------------------------------------------------===// + +// `Math` directory methods: +void populateLoweringONNXElementwiseOpToTorchPattern( + TypeConverter &, RewritePatternSet &, MLIRContext *); + +// `Tensor` directory methods: +void populateLoweringONNXConstantOpToTorchPattern( + TypeConverter &, RewritePatternSet &, MLIRContext *); +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/Tensor/Constant.cpp b/src/Conversion/ONNXToTorch/Tensor/Constant.cpp new file mode 100644 index 0000000000..b8f0d64a92 --- /dev/null +++ b/src/Conversion/ONNXToTorch/Tensor/Constant.cpp @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Constant.cpp - ONNXConstantOp -------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ========================================================================= +// +// This file lowers ONNXConstantOp to Torch::NonValueTensorLiteralOp +// +//===------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; +using namespace mlir::torch; + +namespace onnx_mlir { + +namespace { +class ConvertONNXConstantOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.value().has_value()) + return rewriter.notifyMatchFailure( + op, "unimplemented: non-dense values are unsupported"); + ElementsAttr value = adaptor.valueAttr(); + auto newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, newResultType, value); + return success(); + } +}; +} // namespace + +void populateLoweringONNXConstantOpToTorchPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.add(typeConverter, ctx); +} + +} // 
namespace onnx_mlir diff --git a/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.cpp b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.cpp new file mode 100644 index 0000000000..1023586bc1 --- /dev/null +++ b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.cpp @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- TorchTypeConversion.cpp - ONNX types to Torch types conversion +//---------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ====================================================================================== +// +// This file contains code to setup type conversions from ONNX types (builtin) +// to Torch types (e.g. torch.tensor) +// +//===-------------------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; +using namespace mlir::torch; + +void onnx_mlir::getTorchTypeConversionDependentDialects( + DialectRegistry ®istry) { + // registry.insert(); +} + +//===----------------------------------------------------------------------===// +// Type conversion setup. 
+//===----------------------------------------------------------------------===// + +static torch::Torch::ValueTensorType getValueTensorFromBuiltinTensor( + TensorType type) { + auto context = type.getContext(); + if (type.isa()) { + return torch::Torch::ValueTensorType::get( + context, type.getShape(), type.getElementType()); + } + return torch::Torch::ValueTensorType::get( + context, None, type.getElementType()); +} + +static void setupTensorToValueTensorConversion( + ConversionTarget &target, TypeConverter &typeConverter) { + target.addLegalOp(); + typeConverter.addConversion([](TensorType type) -> Optional { + return getValueTensorFromBuiltinTensor(type); + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, Torch::ValueTensorType type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]) + .getResult(0); + }); + auto sourceMaterialization = [](OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) -> Value { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]) + .getResult(0); + }; + typeConverter.addSourceMaterialization(sourceMaterialization); + typeConverter.addArgumentMaterialization(sourceMaterialization); +} + +void onnx_mlir::setupTorchTypeConversion( + ConversionTarget &target, TypeConverter &typeConverter) { + setupTensorToValueTensorConversion(target, typeConverter); +} diff --git a/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp new file mode 100644 index 0000000000..88c99d4d0a --- /dev/null +++ b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- TorchTypeConversion.hpp - ONNX types to Torch types conversion +//---------===// +// +// Copyright 2019-2022 The IBM Research 
Authors. +// +// ====================================================================================== +// +// This file contains code to setup type conversions from ONNX types (builtin) +// to Torch types (e.g. torch.tensor) +// +//===-------------------------------------------------------------------------------===// + +#ifndef ONNXMLIR_DIALECT_TORCHTYPECONVERSION_H +#define ONNXMLIR_DIALECT_TORCHTYPECONVERSION_H + +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace onnx_mlir { + +/// Get the dependent dialects which might be involved in a backend type +/// conversion. +void getTorchTypeConversionDependentDialects(DialectRegistry ®istry); + +void setupTorchTypeConversion( + ConversionTarget &target, TypeConverter &typeConverter); +} // namespace onnx_mlir + +#endif // ONNXMLIR_DIALECT_TORCHTYPECONVERSION_H diff --git a/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversionPasses.cpp b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversionPasses.cpp new file mode 100644 index 0000000000..fd11c1bf8c --- /dev/null +++ b/src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversionPasses.cpp @@ -0,0 +1,179 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- TorchTypeConversionPasses.cpp - ONNX types to Torch types conversion +// passes ---------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// =================================================================================================== +// +// This file defines additional passes for finishing the function type +// conversion as well as finalizing the type conversion to Torch types. 
+// +//===--------------------------------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" + +using namespace mlir; +using namespace mlir::torch; +using namespace onnx_mlir; + +//===----------------------------------------------------------------------===// +// FuncTorchTypeConversionPass +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { +struct FuncTorchTypeConversionPass + : public PassWrapper> { + + StringRef getArgument() const override { + return "convert-function-types-to-torch-types"; + } + + StringRef getDescription() const override { + return "Convert types in function calls and definitions to torch types."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. 
+ FuncTorchTypeConversionPass() = default; + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + typeConverter.addConversion([](Type type) { return type; }); + onnx_mlir::setupTorchTypeConversion(target, typeConverter); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addLegalOp(); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern( + op, typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); + + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace onnx_mlir + +// std::unique_ptr> +std::unique_ptr onnx_mlir::createFuncTorchTypeConversionPass() { + return std::make_unique(); +} + +//===----------------------------------------------------------------------===// +// FinalizingTorchTypeConversionPass +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { +// In a finalizing conversion, we know that all of the source types have been +// converted to the destination types, so the materialization becomes an +// identity. 
+template +class FinalizeMaterialization : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getOperands()[0]); + return success(); + } +}; +} // namespace onnx_mlir + +template +static void setupFinalization(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter) { + target.addIllegalOp(); + patterns.add>( + typeConverter, patterns.getContext()); +} + +template +static void setupFinalization(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter) { + setupFinalization(target, patterns, typeConverter); + setupFinalization(target, patterns, typeConverter); +} + +namespace onnx_mlir { +struct FinalizingTorchTypeConversionPass + : public PassWrapper> { + + StringRef getArgument() const override { + return "finalize-torch-type-conversion"; + } + + StringRef getDescription() const override { + return "Finalize the conversion from builtin types to torch types."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. + FinalizingTorchTypeConversionPass() = default; + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + typeConverter.addConversion([](Type type) { return type; }); + onnx_mlir::setupTorchTypeConversion(target, typeConverter); + + // Mark materializations as illegal in this pass (since we are finalizing) + // and add patterns that eliminate them. + setupFinalization( + target, patterns, typeConverter); + + // If all result types are legal, and all block arguments are legal, then + // all types in the program are legal. 
+ // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents the patterns from updating + // the types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace onnx_mlir + +// std::unique_ptr> +std::unique_ptr +onnx_mlir::createFinalizingTorchTypeConversionPass() { + return std::make_unique(); +} diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index b6a8e5a09e..eefd45c6aa 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "src/Pass/Passes.hpp" namespace onnx_mlir { @@ -110,6 +111,14 @@ void initOMPasses(int optLevel) { mlir::registerPass( []() -> std::unique_ptr { return createLowerToMhloPass(); }); #endif + +#ifdef ONNX_MLIR_ENABLE_TORCH + // mlir::registerPass( + // []() -> std::unique_ptr { return createLowerToTorchPass(); + // }); + + onnx_mlir::registerONNXFrontendToTorchBackendPasses(); +#endif } } // namespace onnx_mlir diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 2f7c8e4789..b7cd2110be 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -14,6 +14,7 @@ #pragma once +#include "mlir/Pass/PassManager.h" #include namespace mlir { @@ -67,6 +68,16 @@ std::unique_ptr createLowerToKrnlPass( std::unique_ptr createLowerToMhloPass(); #endif +#ifdef ONNX_MLIR_ENABLE_TORCH +/// Add passes for lowering to Torch Backend IR. 
+void createONNXFrontendToTorchBackendPasses(mlir::OpPassManager &pm); +void registerONNXFrontendToTorchBackendPasses(); +std::unique_ptr createLowerToTorchPass(); +std::unique_ptr createFuncTorchTypeConversionPass(); +std::unique_ptr createFinalizingTorchTypeConversionPass(); +std::unique_ptr createEraseONNXEntryPointPass(); +#endif + /// Pass for lowering krnl.dim operations to standard dialect. std::unique_ptr createDisconnectKrnlDimFromAllocPass(); diff --git a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp index 6cfe38d634..6167239bdc 100644 --- a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp +++ b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp @@ -38,6 +38,10 @@ #include "src/InitOMPasses.hpp" #include "src/Pass/Passes.hpp" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" + using namespace mlir; using namespace onnx_mlir; @@ -134,6 +138,8 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); // Initialize accelerators if they exist. 
onnx_mlir::accel::initAccelerators(maccel); diff --git a/test/mlir/CMakeLists.txt b/test/mlir/CMakeLists.txt index 91d5c77311..4afea0c84a 100644 --- a/test/mlir/CMakeLists.txt +++ b/test/mlir/CMakeLists.txt @@ -20,6 +20,12 @@ else() set(ONNX_MLIR_MHLO_ENABLED 0) endif() +if (ONNX_MLIR_ENABLE_TORCH) + set(ONNX_MLIR_TORCH_ENABLED 1) +else() + set(ONNX_MLIR_TORCH_ENABLED 0) +endif() + # Set LLVM_DEFAULT_EXTERNAL_LIT to an empty string to avoid warnings about the path # when using multi-config generators such as VS or Xcode set(LLVM_DEFAULT_EXTERNAL_LIT "") diff --git a/test/mlir/conversion/onnx_to_torch/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_torch/Math/Elementwise.mlir new file mode 100644 index 0000000000..2351780132 --- /dev/null +++ b/test/mlir/conversion/onnx_to_torch/Math/Elementwise.mlir @@ -0,0 +1,23 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-torch-pipeline --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_add +// CHECK-SAME: ([[PARAM_0_:%.+]]: !torch.vtensor<[10,10],f32>, [[PARAM_1_:%.+]]: !torch.vtensor<[10,10],f32>) -> !torch.vtensor<[10,10],f32> { +// CHECK-NEXT: [[INT_1_:%.+]] = torch.constant.int 1 +// CHECK-NEXT: [[VAR_0_:%.+]] = torch.aten.add.Tensor [[PARAM_0_]], [[PARAM_1_]], [[INT_1_]] : !torch.vtensor<[10,10],f32>, !torch.vtensor<[10,10],f32>, !torch.int -> !torch.vtensor<[10,10],f32> +// CHECK-NEXT: return [[VAR_0_]] : !torch.vtensor<[10,10],f32> +// CHECK-NEXT: } +} + +func.func @test_add_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "func.return"(%0) : (tensor) -> () +// CHECK-LABEL: func @test_add_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: !torch.vtensor<[?,10],f32>, [[PARAM_1_:%.+]]: 
!torch.vtensor<[?,10],f32>) -> !torch.vtensor<[?,10],f32> { +// CHECK-NEXT: [[INT_1_:%.+]] = torch.constant.int 1 +// CHECK-NEXT: [[VAR_0_:%.+]] = torch.aten.add.Tensor [[PARAM_0_]], [[PARAM_1_]], [[INT_1_]] : !torch.vtensor<[?,10],f32>, !torch.vtensor<[?,10],f32>, !torch.int -> !torch.vtensor<[?,10],f32> +// CHECK-NEXT: return [[VAR_0_]] : !torch.vtensor<[?,10],f32> +// CHECK-NEXT: } +} diff --git a/test/mlir/conversion/onnx_to_torch/Tensor/Constant.mlir b/test/mlir/conversion/onnx_to_torch/Tensor/Constant.mlir new file mode 100644 index 0000000000..ffc1e92ea6 --- /dev/null +++ b/test/mlir/conversion/onnx_to_torch/Tensor/Constant.mlir @@ -0,0 +1,49 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-torch-pipeline %s -split-input-file | FileCheck %s + +func.func @test_scalar_attr() -> tensor { + %0 = "onnx.Constant"() {value = dense<1.0> : tensor} : () -> tensor + %1 = "onnx.Constant"() {value = dense<2.0> : tensor} : () -> tensor + %2 = "onnx.Add"(%0, %1) : (tensor , tensor) -> tensor + "func.return"(%2) : (tensor) -> () +// CHECK-LABEL: @test_scalar_attr() -> !torch.vtensor<[],f32> +// CHECK-DAG: [[VAR_0_:%.+]] = torch.vtensor.literal(dense<1.000000e+00> : tensor) : !torch.vtensor<[],f32> +// CHECK-DAG: [[VAR_1_:%.+]] = torch.vtensor.literal(dense<2.000000e+00> : tensor) : !torch.vtensor<[],f32> +} + +// ----- + +func.func @test_single_value_attr() -> tensor<1xf32> { + %0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32> + %1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32> + "func.return"(%2) : (tensor<1xf32>) -> () +// CHECK-LABEL: @test_single_value_attr() -> !torch.vtensor<[1],f32> +// CHECK-DAG: [[VAR_0_:%.+]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK-DAG: [[VAR_1_:%.+]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> +} + +// ----- + 
+func.func @test_splat_attr() -> tensor<3xf32> { + %0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<2.0> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> + "func.return"(%2) : (tensor<3xf32>) -> () +// CHECK-LABEL: @test_splat_attr() -> !torch.vtensor<[3],f32> +// CHECK-DAG: [[VAR_0_:%.+]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> +// CHECK-DAG: [[VAR_1_:%.+]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> +} + +// ----- + +func.func @test_splat_nonsplat_attrs() -> tensor<3xf32> { + %0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> + "func.return"(%2) : (tensor<3xf32>) -> () +// CHECK-LABEL: @test_splat_nonsplat_attrs() -> !torch.vtensor<[3],f32> +// CHECK-DAG: [[VAR_0_:%.+]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> +// CHECK-DAG: [[VAR_1_:%.+]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32> +} + +// ----- diff --git a/test/mlir/conversion/onnx_to_torch/lit.local.cfg b/test/mlir/conversion/onnx_to_torch/lit.local.cfg new file mode 100644 index 0000000000..d8cf88dd77 --- /dev/null +++ b/test/mlir/conversion/onnx_to_torch/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_torch: + config.unsupported = True diff --git a/test/mlir/lit.site.cfg.py.in b/test/mlir/lit.site.cfg.py.in index 5136390119..3062b212fb 100644 --- a/test/mlir/lit.site.cfg.py.in +++ b/test/mlir/lit.site.cfg.py.in @@ -9,6 +9,7 @@ config.onnx_mlir_tools_dir = r"@ONNX_MLIR_TOOLS_DIR@" config.onnx_mlir_obj_root = r"@ONNX_MLIR_BIN_ROOT@" config.enable_mhlo = 
@ONNX_MLIR_MHLO_ENABLED@ +config.enable_torch = @ONNX_MLIR_TORCH_ENABLED@ config.enable_nnpa= 0x0@NNPA_LIT_ENABLED@ # Support substitution of the tools_dir with user parameters. This is diff --git a/third_party/torch-mlir b/third_party/torch-mlir new file mode 160000 index 0000000000..125a59decb --- /dev/null +++ b/third_party/torch-mlir @@ -0,0 +1 @@ +Subproject commit 125a59decb5d6e87d878b1097c35193df2c79690