Skip to content

Commit 687a7d7

Browse files
committed
Initialize conversion passes from ONNX to Torch backend and add IR tests for ONNXAddOp and ONNXConstantOp
Signed-off-by: Quinn Dawkins <[email protected]>
1 parent df104fb commit 687a7d7

24 files changed

+784
-1
lines changed

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@
1313
[submodule "third_party/mlir-hlo"]
1414
path = third_party/mlir-hlo
1515
url = https://github.com/tensorflow/mlir-hlo.git
16+
[submodule "third_party/torch-mlir"]
17+
path = third_party/torch-mlir
18+
url = https://github.com/llvm/torch-mlir
19+
shallow = true

CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ project(onnx-mlir)
88
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
99
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
1010
option(ONNX_MLIR_ENABLE_MHLO "Enable MHLO support." ON)
11+
option(ONNX_MLIR_ENABLE_TORCH "Enable Torch support." ON)
1112
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
1213
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
1314

@@ -152,6 +153,12 @@ if (ONNX_MLIR_ENABLE_MHLO)
152153
add_subdirectory(third_party/mlir-hlo EXCLUDE_FROM_ALL)
153154
endif()
154155

156+
if (ONNX_MLIR_ENABLE_TORCH)
157+
set(TORCH_MLIR_ENABLE_MHLO OFF)
158+
set(TORCH_MLIR_OOT_BUILD ON)
159+
add_subdirectory(third_party/torch-mlir EXCLUDE_FROM_ALL)
160+
endif()
161+
155162
if (NOT TARGET benchmark)
156163
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
157164
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
@@ -185,6 +192,10 @@ if (ONNX_MLIR_ENABLE_MHLO)
185192
add_compile_definitions(ONNX_MLIR_ENABLE_MHLO)
186193
endif()
187194

195+
if (ONNX_MLIR_ENABLE_TORCH)
196+
add_compile_definitions(ONNX_MLIR_ENABLE_TORCH)
197+
endif()
198+
188199
add_subdirectory(utils)
189200
add_subdirectory(include)
190201
add_subdirectory(src)

docs/BuildOnLinuxOSX.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@ cmake -G Ninja ../llvm \
2626
-DLLVM_TARGETS_TO_BUILD="host" \
2727
-DCMAKE_BUILD_TYPE=Release \
2828
-DLLVM_ENABLE_ASSERTIONS=ON \
29-
-DLLVM_ENABLE_RTTI=ON
29+
-DLLVM_ENABLE_RTTI=ON \
30+
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
3031

3132
cmake --build . -- ${MAKEFLAGS}
3233
cmake --build . --target check-mlir
3334
```
3435

36+
If building onnx-mlir with `ONNX_MLIR_ENABLE_TORCH=OFF`, then `MLIR_ENABLE_BINDINGS_PYTHON` is not needed.
37+
3538
## ONNX-MLIR (this project)
3639

3740
### Build

src/Conversion/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ add_subdirectory(ONNXToTOSA)
99
if (ONNX_MLIR_ENABLE_MHLO)
1010
add_subdirectory(ONNXToMhlo)
1111
endif()
12+
13+
if (ONNX_MLIR_ENABLE_TORCH)
14+
add_subdirectory(ONNXToTorch)
15+
endif()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-License-Identifier: Apache-2.0

# Please keep in alphabetical order.

# NOTE(review): torch-mlir appears to be added EXCLUDE_FROM_ALL elsewhere in
# the build, so these targets are installed explicitly here — confirm both
# are actually required by downstream consumers.
install(TARGETS
  TorchMLIRTorchDialect
  TorchMLIRTorchUtils
  )

add_onnx_mlir_library(OMONNXToTorch
  ConvertONNXToTorch.cpp
  ConvertONNXToTorchPipeline.cpp
  EraseONNXEntryPoint.cpp
  ONNXToTorchCommon.cpp
  TypeConversion/TorchTypeConversion.cpp
  TypeConversion/TorchTypeConversionPasses.cpp

  Math/Elementwise.cpp

  Tensor/Constant.cpp

  LINK_LIBS PUBLIC
  TorchMLIRTorchDialect
  TorchMLIRTorchUtils
  OMONNXOps
  )
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//====------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering
6+
//-------===//
7+
//
8+
// Copyright 2019-2022 The IBM Research Authors.
9+
//
10+
// ===============================================================================
11+
//
12+
// This file implements the lowering of frontend operations to Torch backend IR.
13+
//
14+
//===------------------------------------------------------------------------===//
15+
16+
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"
17+
18+
using namespace mlir;
19+
20+
namespace onnx_mlir {
21+
22+
/// Aggregates every ONNX-to-Torch rewrite pattern into `patterns`.
/// Each op category contributes through its own populate* helper so this
/// function remains the single registration point for the lowering pass.
void populateONNXToTorchConversionPattern(TypeConverter &typeConverter,
    RewritePatternSet &patterns, MLIRContext *ctx) {
  // Math ops.
  populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx);
  // Tensor ops.
  populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx);
}
28+
29+
//===----------------------------------------------------------------------===//
// Frontend to Torch backend dialect lowering pass
//===----------------------------------------------------------------------===//
32+
33+
struct FrontendToTorchLoweringPass
34+
: public PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>> {
35+
36+
StringRef getArgument() const override { return "convert-onnx-to-torch"; }
37+
38+
StringRef getDescription() const override {
39+
return "Lower frontend ops to Torch dialect.";
40+
}
41+
42+
// Make sure that we have a valid default constructor and copy
43+
// constructor to make sure that the options are initialized properly.
44+
FrontendToTorchLoweringPass() = default;
45+
FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass)
46+
: PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>>() {}
47+
48+
void runOnOperation() final;
49+
};
50+
51+
void FrontendToTorchLoweringPass::runOnOperation() {
52+
ModuleOp module = getOperation();
53+
// The first thing to define is the conversion target. This will define the
54+
// final target for this lowering.
55+
ConversionTarget target(getContext());
56+
57+
// We define the specific operations, or dialects, that are legal targets for
58+
// this lowering.
59+
target.addLegalDialect<torch::Torch::TorchDialect, func::FuncDialect>();
60+
61+
TypeConverter typeConverter;
62+
typeConverter.addConversion([](Type type) { return type; });
63+
onnx_mlir::setupTorchTypeConversion(target, typeConverter);
64+
65+
// Now that the conversion target has been defined, we just need to provide
66+
// the set of patterns that will lower the frontend operations.
67+
RewritePatternSet patterns(&getContext());
68+
69+
// Define patterns.
70+
populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext());
71+
72+
// With the target and rewrite patterns defined, we can now attempt the
73+
// conversion. The conversion will signal failure if any of our `illegal`
74+
// operations were not converted successfully.
75+
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
76+
signalPassFailure();
77+
}
78+
}
79+
80+
/// Factory for the ONNX-frontend-to-Torch lowering pass.
std::unique_ptr<Pass> createLowerToTorchPass() {
  return std::make_unique<FrontendToTorchLoweringPass>();
}
83+
84+
} // namespace onnx_mlir
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//====------ ConvertONNXToTorchPipeline.cpp - ONNX dialects to Torch lowering
6+
// pipeline -------===//
7+
//
8+
// Copyright 2019-2022 The IBM Research Authors.
9+
//
10+
// ================================================================================================
11+
//
12+
// This file registers the pipeline for converting ONNX to Torch Backend IR
13+
//
14+
//===-----------------------------------------------------------------------------------------===//
15+
16+
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"
17+
18+
using namespace mlir;
19+
20+
namespace onnx_mlir {
21+
22+
/// Registers the `convert-onnx-to-torch-pipeline` pass pipeline with the
/// global pipeline registry so it is usable from the command line.
void registerONNXFrontendToTorchBackendPasses() {
  PassPipelineRegistration<>("convert-onnx-to-torch-pipeline",
      "Pipeline converting ONNX to Torch dialect.",
      onnx_mlir::createONNXFrontendToTorchBackendPasses);
}
27+
28+
/// Builds the ONNX-to-Torch backend pipeline on `pm`: lower the ops,
/// convert function signatures and finalize remaining types, then erase
/// the ONNX entry point op, which has no Torch backend counterpart.
void createONNXFrontendToTorchBackendPasses(OpPassManager &pm) {
  pm.addPass(createLowerToTorchPass());
  pm.addPass(createFuncTorchTypeConversionPass());
  pm.addPass(createFinalizingTorchTypeConversionPass());
  pm.addPass(createEraseONNXEntryPointPass());
}
34+
35+
} // namespace onnx_mlir
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//====----- EraseONNXEntryPoint.cpp - Erase ONNX entry point op ---------===//
7+
//
8+
// Copyright 2019-2022 The IBM Research Authors.
9+
//
10+
// ====================================================================================
11+
//
12+
// This file implements a pass for removing the ONNXEntryPointOp for
13+
// compatibility when converting to Torch backend IR.
14+
//
15+
//===------------------------------------------------------------------------------===//
16+
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/IR/BlockAndValueMapping.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/BuiltinOps.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"
24+
25+
using namespace mlir;
26+
using namespace onnx_mlir;
27+
28+
namespace onnx_mlir {
29+
struct EraseONNXEntryPointPass
30+
: public PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>> {
31+
32+
StringRef getArgument() const override { return "erase-onnx-entry-point"; }
33+
34+
StringRef getDescription() const override {
35+
return "Erase ONNXEntryPointOp.";
36+
}
37+
38+
// Make sure that we have a valid default constructor and copy
39+
// constructor to make sure that the options are initialized properly.
40+
EraseONNXEntryPointPass() = default;
41+
EraseONNXEntryPointPass(const EraseONNXEntryPointPass &pass)
42+
: PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>>() {}
43+
44+
void runOnOperation() override {
45+
auto walkResult = getOperation().walk([](ONNXEntryPointOp op) {
46+
op.erase();
47+
return WalkResult::advance();
48+
});
49+
if (walkResult.wasInterrupted()) {
50+
return signalPassFailure();
51+
}
52+
}
53+
};
54+
} // namespace onnx_mlir
55+
56+
// std::unique_ptr<OperationPass<ModuleOp>>
57+
std::unique_ptr<mlir::Pass> onnx_mlir::createEraseONNXEntryPointPass() {
58+
return std::make_unique<EraseONNXEntryPointPass>();
59+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===---------------- Elementwise.cpp - Elementwise Ops -------------------===//
6+
//
7+
// Copyright 2019-2022 The IBM Research Authors.
8+
//
9+
// =============================================================================
10+
//
11+
// This file lowers ONNX element-wise operators to Torch dialect.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"
16+
17+
using namespace mlir;
18+
using namespace mlir::torch;
19+
20+
namespace onnx_mlir {
21+
22+
namespace {
23+
24+
// AtenAddOp requires an additional alpha parameter and thus requires a unique
25+
// lowering
26+
class ConvertONNXAddOp : public OpConversionPattern<ONNXAddOp> {
27+
public:
28+
using OpConversionPattern::OpConversionPattern;
29+
LogicalResult matchAndRewrite(ONNXAddOp op, OpAdaptor adaptor,
30+
ConversionPatternRewriter &rewriter) const override {
31+
Location loc = op.getLoc();
32+
Value one = rewriter.create<Torch::ConstantIntOp>(
33+
loc, rewriter.getI64IntegerAttr(1));
34+
auto newResultType =
35+
getTypeConverter()->convertType(op.getResult().getType());
36+
rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
37+
op, newResultType, adaptor.A(), adaptor.B(), one);
38+
return success();
39+
}
40+
};
41+
} // namespace
42+
43+
/// Registers the element-wise ONNX-to-Torch lowering patterns (currently
/// only ONNXAddOp) into `patterns`.
void populateLoweringONNXElementwiseOpToTorchPattern(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    MLIRContext *ctx) {
  patterns.add<ConvertONNXAddOp>(typeConverter, ctx);
}
48+
49+
} // namespace onnx_mlir
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//====----- ONNXToTorchCommon.cpp - ONNX dialects to Torch lowering
6+
//---------===//
7+
//
8+
// Copyright 2019-2022 The IBM Research Authors.
9+
//
10+
// ===============================================================================
11+
//
12+
// This file contains common code shared by the functions performing the
13+
// lowering to the Torch backend dialect.
14+
//
15+
//===------------------------------------------------------------------------===//
16+
17+
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp"
18+
19+
namespace onnx_mlir {} // namespace onnx_mlir

0 commit comments

Comments
 (0)