Skip to content

Commit

Permalink
Initialize conversion passes from ONNX to Torch backend and add IR te…
Browse files Browse the repository at this point in the history
…sts for ONNXAddOp and ONNXConstantOp
  • Loading branch information
qedawkins committed Sep 22, 2022
1 parent 8f0fff3 commit 67caa8a
Show file tree
Hide file tree
Showing 24 changed files with 783 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "third_party/mlir-hlo"]
path = third_party/mlir-hlo
url = https://github.com/tensorflow/mlir-hlo.git
[submodule "third_party/torch-mlir"]
path = third_party/torch-mlir
url = https://github.com/llvm/torch-mlir
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ project(onnx-mlir)
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
option(ONNX_MLIR_ENABLE_MHLO "Enable MHLO support." ON)
option(ONNX_MLIR_ENABLE_TORCH "Enable Torch support." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)

Expand Down Expand Up @@ -152,6 +153,12 @@ if (ONNX_MLIR_ENABLE_MHLO)
add_subdirectory(third_party/mlir-hlo EXCLUDE_FROM_ALL)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
set(TORCH_MLIR_ENABLE_MHLO OFF)
set(TORCH_MLIR_OOT_BUILD ON)
add_subdirectory(third_party/torch-mlir EXCLUDE_FROM_ALL)
endif()

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
Expand Down Expand Up @@ -185,6 +192,10 @@ if (ONNX_MLIR_ENABLE_MHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_MHLO)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
add_compile_definitions(ONNX_MLIR_ENABLE_TORCH)
endif()

add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
Expand Down
5 changes: 4 additions & 1 deletion docs/BuildOnLinuxOSX.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ cmake -G Ninja ../llvm \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON
-DLLVM_ENABLE_RTTI=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON

cmake --build . -- ${MAKEFLAGS}
cmake --build . --target check-mlir
```

If building onnx-mlir with `ONNX_MLIR_ENABLE_TORCH=OFF`, then `MLIR_ENABLE_BINDINGS_PYTHON` is not needed.

## ONNX-MLIR (this project)

### Build
Expand Down
4 changes: 4 additions & 0 deletions src/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ add_subdirectory(ONNXToTOSA)
if (ONNX_MLIR_ENABLE_MHLO)
add_subdirectory(ONNXToMhlo)
endif()

if (ONNX_MLIR_ENABLE_TORCH)
add_subdirectory(ONNXToTorch)
endif()
26 changes: 26 additions & 0 deletions src/Conversion/ONNXToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

# Please keep in alphabetical order.

install(TARGETS
TorchMLIRTorchDialect
TorchMLIRTorchUtils
)

add_onnx_mlir_library(OMONNXToTorch
ConvertONNXToTorch.cpp
ConvertONNXToTorchPipeline.cpp
EraseONNXEntryPoint.cpp
ONNXToTorchCommon.cpp
TypeConversion/TorchTypeConversion.cpp
TypeConversion/TorchTypeConversionPasses.cpp

Math/Elementwise.cpp

Tensor/Constant.cpp

LINK_LIBS PUBLIC
TorchMLIRTorchDialect
TorchMLIRTorchUtils
OMONNXOps
)
84 changes: 84 additions & 0 deletions src/Conversion/ONNXToTorch/ConvertONNXToTorch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering -----===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file implements the lowering of frontend operations to Torch backend IR.
//
//===------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Populates `patterns` with every ONNX-to-Torch lowering pattern provided by
/// the per-directory populate functions (Math, Tensor, ...).
void populateONNXToTorchConversionPattern(TypeConverter &typeConverter,
    RewritePatternSet &patterns, MLIRContext *ctx) {
  // Math
  populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx);
  // Tensor
  populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx);
}

//===----------------------------------------------------------------------===//
// Frontend to Torch backend dialect lowering pass
//===----------------------------------------------------------------------===//

/// Pass lowering ONNX frontend operations to the Torch backend dialect.
/// Registered under the command-line argument "convert-onnx-to-torch".
struct FrontendToTorchLoweringPass
    : public PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>> {

  StringRef getArgument() const override { return "convert-onnx-to-torch"; }

  StringRef getDescription() const override {
    return "Lower frontend ops to Torch dialect.";
  }

  // Make sure that we have a valid default constructor and copy
  // constructor to make sure that the options are initialized properly.
  FrontendToTorchLoweringPass() = default;
  FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass)
      : PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>>() {}

  void runOnOperation() final;
};

void FrontendToTorchLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());

// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target.addLegalDialect<torch::Torch::TorchDialect, func::FuncDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
onnx_mlir::setupTorchTypeConversion(target, typeConverter);

// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the frontend operations.
RewritePatternSet patterns(&getContext());

// Define patterns.
populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext());

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

/// Factory for the "convert-onnx-to-torch" lowering pass.
std::unique_ptr<Pass> createLowerToTorchPass() {
  return std::make_unique<FrontendToTorchLoweringPass>();
}

} // namespace onnx_mlir
35 changes: 35 additions & 0 deletions src/Conversion/ONNXToTorch/ConvertONNXToTorchPipeline.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ConvertONNXToTorchPipeline.cpp - ONNX dialects to Torch lowering
//pipeline -------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ================================================================================================
//
// This file registers the pipeline for converting ONNX to Torch Backend IR
//
//===-----------------------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Registers the "convert-onnx-to-torch-pipeline" pass pipeline so it can be
/// requested by name (e.g. from the command line).
void registerONNXFrontendToTorchBackendPasses() {
  PassPipelineRegistration<>("convert-onnx-to-torch-pipeline",
      "Pipeline converting ONNX to Torch dialect.",
      onnx_mlir::createONNXFrontendToTorchBackendPasses);
}

/// Builds the ONNX-to-Torch-backend pipeline: lower the ops, convert the
/// function-level types, finalize remaining type conversions, then drop the
/// ONNX entry-point op, which has no Torch backend equivalent.
void createONNXFrontendToTorchBackendPasses(OpPassManager &pm) {
  pm.addPass(createLowerToTorchPass());
  pm.addPass(createFuncTorchTypeConversionPass());
  pm.addPass(createFinalizingTorchTypeConversionPass());
  pm.addPass(createEraseONNXEntryPointPass());
}

} // namespace onnx_mlir
59 changes: 59 additions & 0 deletions src/Conversion/ONNXToTorch/EraseONNXEntryPoint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------ EraseONNXEntryPoint.cpp - Erase the ONNX entry-point op ------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ====================================================================================
//
// This file implements a pass for removing the ONNXEntryPointOp for
// compatibility when converting to Torch backend IR.
//
//===------------------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace onnx_mlir {
/// Pass that deletes every ONNXEntryPointOp in the module, for compatibility
/// when handing the module to the Torch backend (which has no notion of an
/// ONNX entry point). Registered as "erase-onnx-entry-point".
struct EraseONNXEntryPointPass
    : public PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>> {

  StringRef getArgument() const override { return "erase-onnx-entry-point"; }

  StringRef getDescription() const override {
    return "Erase ONNXEntryPointOp.";
  }

  // Make sure that we have a valid default constructor and copy
  // constructor to make sure that the options are initialized properly.
  EraseONNXEntryPointPass() = default;
  EraseONNXEntryPointPass(const EraseONNXEntryPointPass &pass)
      : PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>>() {}

  void runOnOperation() override {
    // Erasing the op currently being visited is safe in a post-order walk.
    // Note: this pass cannot fail — the previous wasInterrupted() check was
    // dead code because the callback always returned WalkResult::advance(),
    // so the signalPassFailure() branch was unreachable.
    getOperation().walk([](ONNXEntryPointOp op) { op.erase(); });
  }
};
} // namespace onnx_mlir

/// Factory for the "erase-onnx-entry-point" pass.
// NOTE(review): returns the generic mlir::Pass rather than
// OperationPass<ModuleOp>; presumably this matches the declaration in
// src/Pass/Passes.hpp — confirm before tightening the return type.
std::unique_ptr<mlir::Pass> onnx_mlir::createEraseONNXEntryPointPass() {
  return std::make_unique<EraseONNXEntryPointPass>();
}
49 changes: 49 additions & 0 deletions src/Conversion/ONNXToTorch/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Elementwise.cpp - Elementwise Ops -------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to Torch dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

using namespace mlir;
using namespace mlir::torch;

namespace onnx_mlir {

namespace {

/// Lowers ONNXAddOp to Torch aten.add.Tensor. Kept apart from any generic
/// elementwise lowering because aten.add.Tensor carries an extra `alpha`
/// operand; it is pinned to 1 so the op computes a plain A + B.
class ConvertONNXAddOp : public OpConversionPattern<ONNXAddOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXAddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto resultType =
        getTypeConverter()->convertType(op.getResult().getType());
    // alpha = 1 leaves the second operand unscaled.
    Value alpha = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
        op, resultType, adaptor.A(), adaptor.B(), alpha);
    return success();
  }
};

} // namespace

/// Adds the Torch lowering patterns for ONNX element-wise ops (currently only
/// ONNXAddOp) to `patterns`.
void populateLoweringONNXElementwiseOpToTorchPattern(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    MLIRContext *ctx) {
  patterns.add<ConvertONNXAddOp>(typeConverter, ctx);
}

} // namespace onnx_mlir
19 changes: 19 additions & 0 deletions src/Conversion/ONNXToTorch/ONNXToTorchCommon.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====----- ONNXToTorchCommon.cpp - ONNX dialects to Torch lowering
//---------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the Torch backend dialect.
//
//===------------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"

namespace onnx_mlir {} // namespace onnx_mlir
52 changes: 52 additions & 0 deletions src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToTorchCommon.hpp - ONNX dialects to Torch lowering
//--------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// ===============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the Torch dialect.
//
//===------------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"

#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

#include "src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp"
#include "src/Dialect/Mlir/IndexExpr.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// Functions for populating the conversion patterns for the lowerings.
//===----------------------------------------------------------------------===//

// `Math` directory methods:
/// Adds the element-wise ONNX op lowerings (e.g. ONNXAddOp) to the set.
void populateLoweringONNXElementwiseOpToTorchPattern(
    TypeConverter &, RewritePatternSet &, MLIRContext *);

// `Tensor` directory methods:
/// Adds the ONNXConstantOp lowering to the set.
void populateLoweringONNXConstantOpToTorchPattern(
    TypeConverter &, RewritePatternSet &, MLIRContext *);
} // namespace onnx_mlir
Loading

0 comments on commit 67caa8a

Please sign in to comment.