-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initialize conversion passes from ONNX to Torch backend and add IR te…
…sts for ONNXAddOp and ONNXConstantOp Signed-off-by: Quinn Dawkins <[email protected]>
- Loading branch information
Showing
23 changed files
with
781 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Please keep in alphabetical order. | ||
|
||
install(TARGETS | ||
TorchMLIRTorchDialect | ||
TorchMLIRTorchUtils | ||
) | ||
|
||
add_onnx_mlir_library(OMONNXToTorch | ||
ConvertONNXToTorch.cpp | ||
ConvertONNXToTorchPipeline.cpp | ||
EraseONNXEntryPoint.cpp | ||
ONNXToTorchCommon.cpp | ||
TypeConversion/TorchTypeConversion.cpp | ||
TypeConversion/TorchTypeConversionPasses.cpp | ||
|
||
Math/Elementwise.cpp | ||
|
||
Tensor/Constant.cpp | ||
|
||
LINK_LIBS PUBLIC | ||
TorchMLIRTorchDialect | ||
TorchMLIRTorchUtils | ||
OMONNXOps | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering | ||
//-------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// =============================================================================== | ||
// | ||
// This file implements the lowering of frontend operations to Torch backend IR. | ||
// | ||
//===------------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
void populateONNXToTorchConversionPattern(TypeConverter &typeConverter, | ||
RewritePatternSet &patterns, MLIRContext *ctx) { | ||
// Math | ||
populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx); | ||
populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Frontend to Mhlo Dialect lowering pass | ||
//===----------------------------------------------------------------------===// | ||
|
||
struct FrontendToTorchLoweringPass | ||
: public PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>> { | ||
|
||
StringRef getArgument() const override { return "convert-onnx-to-torch"; } | ||
|
||
StringRef getDescription() const override { | ||
return "Lower frontend ops to Torch dialect."; | ||
} | ||
|
||
// Make sure that we have a valid default constructor and copy | ||
// constructor to make sure that the options are initialized properly. | ||
FrontendToTorchLoweringPass() = default; | ||
FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass) | ||
: PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>>() {} | ||
|
||
void runOnOperation() final; | ||
}; | ||
|
||
void FrontendToTorchLoweringPass::runOnOperation() { | ||
ModuleOp module = getOperation(); | ||
// The first thing to define is the conversion target. This will define the | ||
// final target for this lowering. | ||
ConversionTarget target(getContext()); | ||
|
||
// We define the specific operations, or dialects, that are legal targets for | ||
// this lowering. | ||
target.addLegalDialect<torch::Torch::TorchDialect, func::FuncDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
onnx_mlir::setupTorchTypeConversion(target, typeConverter); | ||
|
||
// Now that the conversion target has been defined, we just need to provide | ||
// the set of patterns that will lower the frontend operations. | ||
RewritePatternSet patterns(&getContext()); | ||
|
||
// Define patterns. | ||
populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext()); | ||
|
||
// With the target and rewrite patterns defined, we can now attempt the | ||
// conversion. The conversion will signal failure if any of our `illegal` | ||
// operations were not converted successfully. | ||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
std::unique_ptr<Pass> createLowerToTorchPass() { | ||
return std::make_unique<FrontendToTorchLoweringPass>(); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====------ ConvertONNXToTorchPipeline.cpp - ONNX dialects to Torch lowering | ||
// pipeline -------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ================================================================================================ | ||
// | ||
// This file registers the pipeline for converting ONNX to Torch Backend IR | ||
// | ||
//===-----------------------------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
void registerONNXFrontendToTorchBackendPasses() { | ||
PassPipelineRegistration<>("convert-onnx-to-torch-pipeline", | ||
"Pipeline converting ONNX to Torch dialect.", | ||
onnx_mlir::createONNXFrontendToTorchBackendPasses); | ||
} | ||
|
||
void createONNXFrontendToTorchBackendPasses(OpPassManager &pm) { | ||
pm.addPass(createLowerToTorchPass()); | ||
pm.addPass(createFuncTorchTypeConversionPass()); | ||
pm.addPass(createFinalizingTorchTypeConversionPass()); | ||
pm.addPass(createEraseONNXEntryPointPass()); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====----- EraseModuleInitializer.cpp - ONNX dialects to Torch lowering | ||
//---------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ==================================================================================== | ||
// | ||
// This file implements a pass for removing the ONNXEntryPointOp for | ||
// compatibility when converting to Torch backend IR. | ||
// | ||
//===------------------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BlockAndValueMapping.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" | ||
|
||
using namespace mlir; | ||
using namespace onnx_mlir; | ||
|
||
namespace onnx_mlir { | ||
struct EraseONNXEntryPointPass | ||
: public PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>> { | ||
|
||
StringRef getArgument() const override { return "erase-onnx-entry-point"; } | ||
|
||
StringRef getDescription() const override { | ||
return "Erase ONNXEntryPointOp."; | ||
} | ||
|
||
// Make sure that we have a valid default constructor and copy | ||
// constructor to make sure that the options are initialized properly. | ||
EraseONNXEntryPointPass() = default; | ||
EraseONNXEntryPointPass(const EraseONNXEntryPointPass &pass) | ||
: PassWrapper<EraseONNXEntryPointPass, OperationPass<ModuleOp>>() {} | ||
|
||
void runOnOperation() override { | ||
auto walkResult = getOperation().walk([](ONNXEntryPointOp op) { | ||
op.erase(); | ||
return WalkResult::advance(); | ||
}); | ||
if (walkResult.wasInterrupted()) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
} // namespace onnx_mlir | ||
|
||
// std::unique_ptr<OperationPass<ModuleOp>> | ||
std::unique_ptr<mlir::Pass> onnx_mlir::createEraseONNXEntryPointPass() { | ||
return std::make_unique<EraseONNXEntryPointPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===---------------- Elementwise.cpp - Elementwise Ops -------------------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX element-wise operators to Torch dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
// AtenAddOp requires an additional alpha parameter and thus requires a unique | ||
// lowering | ||
class ConvertONNXAddOp : public OpConversionPattern<ONNXAddOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
LogicalResult matchAndRewrite(ONNXAddOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
Value one = rewriter.create<Torch::ConstantIntOp>( | ||
loc, rewriter.getI64IntegerAttr(1)); | ||
auto newResultType = | ||
getTypeConverter()->convertType(op.getResult().getType()); | ||
rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>( | ||
op, newResultType, adaptor.A(), adaptor.B(), one); | ||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void populateLoweringONNXElementwiseOpToTorchPattern( | ||
TypeConverter &typeConverter, RewritePatternSet &patterns, | ||
MLIRContext *ctx) { | ||
patterns.add<ConvertONNXAddOp>(typeConverter, ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====----- ONNXToTorchCommon.cpp - ONNX dialects to Torch lowering | ||
//---------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// =============================================================================== | ||
// | ||
// This file contains common code shared by the functions performing the | ||
// lowering to the Torch backend dialect. | ||
// | ||
//===------------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" | ||
|
||
namespace onnx_mlir {} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====------ ONNXToTorchCommon.hpp - ONNX dialects to Torch lowering | ||
//--------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// =============================================================================== | ||
// | ||
// This file contains common code shared by the functions performing the | ||
// lowering to the Torch dialect. | ||
// | ||
//===------------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "llvm/ADT/ArrayRef.h" | ||
#include "llvm/ADT/Sequence.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
|
||
#include "src/Conversion/ONNXToTorch/TypeConversion/TorchTypeConversion.hpp" | ||
#include "src/Dialect/Mlir/IndexExpr.hpp" | ||
#include "src/Dialect/ONNX/DialectBuilder.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps.hpp" | ||
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" | ||
#include "src/Pass/Passes.hpp" | ||
#include "src/Transform/ONNX/ConstPropHelper.hpp" | ||
|
||
namespace onnx_mlir { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Functions for populating the conversion patterns for the lowerings. | ||
//===----------------------------------------------------------------------===// | ||
|
||
// `Math` directory methods: | ||
void populateLoweringONNXElementwiseOpToTorchPattern( | ||
TypeConverter &, RewritePatternSet &, MLIRContext *); | ||
|
||
// `Tensor` directory methods: | ||
void populateLoweringONNXConstantOpToTorchPattern( | ||
TypeConverter &, RewritePatternSet &, MLIRContext *); | ||
} // namespace onnx_mlir |
Oops, something went wrong.