Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize conversion passes from ONNX to Torch-MLIR backend contract #1731

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@
[submodule "third_party/mlir-hlo"]
path = third_party/mlir-hlo
url = https://github.com/tensorflow/mlir-hlo.git
[submodule "third_party/torch-mlir"]
path = third_party/torch-mlir
url = https://github.com/nod-ai/torch-mlir.git
branch = external-builder
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ project(onnx-mlir)
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
option(ONNX_MLIR_ENABLE_MHLO "Enable MHLO support." ON)
option(ONNX_MLIR_ENABLE_TORCH "Enable Torch support." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)

Expand Down Expand Up @@ -152,6 +153,13 @@ if (ONNX_MLIR_ENABLE_MHLO)
add_subdirectory(third_party/mlir-hlo EXCLUDE_FROM_ALL)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
set(TORCH_MLIR_ENABLE_MHLO OFF)
set(TORCH_MLIR_ENABLE_LTC OFF)
set(TORCH_MLIR_OUT_OF_TREE_BUILD ON)
add_subdirectory(third_party/torch-mlir EXCLUDE_FROM_ALL)
endif()

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
Expand Down Expand Up @@ -185,6 +193,10 @@ if (ONNX_MLIR_ENABLE_MHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_MHLO)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
add_compile_definitions(ONNX_MLIR_ENABLE_TORCH)
endif()

add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
Expand Down
4 changes: 4 additions & 0 deletions src/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ add_subdirectory(ONNXToTOSA)
if (ONNX_MLIR_ENABLE_MHLO)
add_subdirectory(ONNXToMhlo)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
add_subdirectory(ONNXToTorch)
endif()
26 changes: 26 additions & 0 deletions src/Conversion/ONNXToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

# Please keep in alphabetical order.

install(TARGETS
TorchMLIRTorchDialect
TorchMLIRTorchUtils
)

add_onnx_mlir_library(OMONNXToTorch
ConvertONNXToTorch.cpp
ConvertONNXToTorchPipeline.cpp
EraseONNXEntryPoint.cpp
ONNXToTorchCommon.cpp
TypeConversion/TorchTypeConversion.cpp
TypeConversion/TorchTypeConversionPasses.cpp

Math/Elementwise.cpp

Tensor/Constant.cpp

LINK_LIBS PUBLIC
TorchMLIRTorchDialect
TorchMLIRTorchUtils
OMONNXOps
)
84 changes: 84 additions & 0 deletions src/Conversion/ONNXToTorch/ConvertONNXToTorch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering
//-------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file implements the lowering of frontend operations to Torch backend IR.
//
//===------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

void populateONNXToTorchConversionPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
// Math
populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx);
populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx);
}

//===----------------------------------------------------------------------===//
// Frontend to Mhlo Dialect lowering pass
//===----------------------------------------------------------------------===//

struct FrontendToTorchLoweringPass
: public PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>> {

StringRef getArgument() const override { return "convert-onnx-to-torch"; }

StringRef getDescription() const override {
return "Lower frontend ops to Torch dialect.";
}

// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
FrontendToTorchLoweringPass() = default;
FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass)
: PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>>() {}

void runOnOperation() final;
};

void FrontendToTorchLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());

// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target.addLegalDialect<torch::Torch::TorchDialect, func::FuncDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
onnx_mlir::setupTorchTypeConversion(target, typeConverter);

// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the frontend operations.
RewritePatternSet patterns(&getContext());

// Define patterns.
populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext());

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

std::unique_ptr<Pass> createLowerToTorchPass() {
return std::make_unique<FrontendToTorchLoweringPass>();
}

} // namespace onnx_mlir
35 changes: 35 additions & 0 deletions src/Conversion/ONNXToTorch/ConvertONNXToTorchPipeline.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ConvertONNXToTorchPipeline.cpp - ONNX dialects to Torch lowering
// pipeline -------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ================================================================================================
//
// This file registers the pipeline for converting ONNX to Torch Backend IR
//
//===-----------------------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

void registerONNXFrontendToTorchBackendPasses() {
PassPipelineRegistration<>("convert-onnx-to-torch-pipeline",
"Pipeline converting ONNX to Torch dialect.",
onnx_mlir::createONNXFrontendToTorchBackendPasses);
}

void createONNXFrontendToTorchBackendPasses(OpPassManager &pm) {
pm.addPass(createLowerToTorchPass());
pm.addPass(createFuncTorchTypeConversionPass());
pm.addPass(createFinalizingTorchTypeConversionPass());
pm.addPass(createEraseONNXEntryPointPass());
}

} // namespace onnx_mlir
59 changes: 59 additions & 0 deletions src/Conversion/ONNXToTorch/EraseONNXEntryPoint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====----- EraseModuleInitializer.cpp - ONNX dialects to Torch lowering
//---------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ====================================================================================
//
// This file implements a pass for removing the ONNXEntryPointOp for
// compatibility when converting to Torch backend IR.
//
//===------------------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace onnx_mlir {
struct EraseONNXEntryPointPass
: public PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>> {

StringRef getArgument() const override { return "erase-onnx-entry-point"; }

StringRef getDescription() const override {
return "Erase ONNXEntryPointOp.";
}

// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
EraseONNXEntryPointPass() = default;
EraseONNXEntryPointPass(const EraseONNXEntryPointPass &pass)
: PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>>() {}

void runOnOperation() override {
auto walkResult = getOperation().walk([](ONNXEntryPointOp op) {
op.erase();
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
return signalPassFailure();
}
}
};
} // namespace onnx_mlir

// std::unique_ptr<OperationPass<ModuleOp>>
std::unique_ptr<mlir::Pass> onnx_mlir::createEraseONNXEntryPointPass() {
return std::make_unique<EraseONNXEntryPointPass>();
}
49 changes: 49 additions & 0 deletions src/Conversion/ONNXToTorch/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Elementwise.cpp - Elementwise Ops -------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to Torch dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;
using namespace mlir::torch;

namespace onnx_mlir {

namespace {

// AtenAddOp requires an additional alpha parameter and thus requires a unique
// lowering
class ConvertONNXAddOp : public OpConversionPattern<ONNXAddOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXAddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
auto newResultType =
getTypeConverter()->convertType(op.getResult().getType());
rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
op, newResultType, adaptor.A(), adaptor.B(), one);
return success();
}
};
} // namespace

void populateLoweringONNXElementwiseOpToTorchPattern(
TypeConverter &typeConverter, RewritePatternSet &patterns,
MLIRContext *ctx) {
patterns.add<ConvertONNXAddOp>(typeConverter, ctx);
}

} // namespace onnx_mlir
19 changes: 19 additions & 0 deletions src/Conversion/ONNXToTorch/ONNXToTorchCommon.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====----- ONNXToTorchCommon.cpp - ONNX dialects to Torch lowering
//---------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the Torch backend dialect.
//
//===------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

namespace onnx_mlir {} // namespace onnx_mlir
52 changes: 52 additions & 0 deletions src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToTorchCommon.hpp - ONNX dialects to Torch lowering
//--------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the Torch dialect.
//
//===------------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"

#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

#include "src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp"
#include "src/Dialect/Mlir/IndexExpr.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// Functions for populating the conversion patterns for the lowerings.
//===----------------------------------------------------------------------===//

// `Math` directory methods:
void populateLoweringONNXElementwiseOpToTorchPattern(
TypeConverter &, RewritePatternSet &, MLIRContext *);

// `Tensor` directory methods:
void populateLoweringONNXConstantOpToTorchPattern(
TypeConverter &, RewritePatternSet &, MLIRContext *);
} // namespace onnx_mlir
Loading