|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: Apache-2.0 |
| 3 | + */ |
| 4 | + |
| 5 | +//====------ ConvertONNXToTorch.cpp - ONNX dialects to Torch lowering |
| 6 | +//-------===// |
| 7 | +// |
| 8 | +// Copyright 2019-2022 The IBM Research Authors. |
| 9 | +// |
| 10 | +// =============================================================================== |
| 11 | +// |
| 12 | +// This file implements the lowering of frontend operations to Torch backend IR. |
| 13 | +// |
| 14 | +//===------------------------------------------------------------------------===// |
| 15 | + |
| 16 | +#include "src/Conversion/ONNXToTorch/ONNXToTorchCommon.hpp" |
| 17 | + |
| 18 | +using namespace mlir; |
| 19 | + |
| 20 | +namespace onnx_mlir { |
| 21 | + |
| 22 | +void populateONNXToTorchConversionPattern(TypeConverter &typeConverter, |
| 23 | + RewritePatternSet &patterns, MLIRContext *ctx) { |
| 24 | + // Math |
| 25 | + populateLoweringONNXElementwiseOpToTorchPattern(typeConverter, patterns, ctx); |
| 26 | + populateLoweringONNXConstantOpToTorchPattern(typeConverter, patterns, ctx); |
| 27 | +} |
| 28 | + |
| 29 | +//===----------------------------------------------------------------------===// |
| 30 | +// Frontend to Mhlo Dialect lowering pass |
| 31 | +//===----------------------------------------------------------------------===// |
| 32 | + |
| 33 | +struct FrontendToTorchLoweringPass |
| 34 | + : public PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>> { |
| 35 | + |
| 36 | + StringRef getArgument() const override { return "convert-onnx-to-torch"; } |
| 37 | + |
| 38 | + StringRef getDescription() const override { |
| 39 | + return "Lower frontend ops to Torch dialect."; |
| 40 | + } |
| 41 | + |
| 42 | + // Make sure that we have a valid default constructor and copy |
| 43 | + // constructor to make sure that the options are initialized properly. |
| 44 | + FrontendToTorchLoweringPass() = default; |
| 45 | + FrontendToTorchLoweringPass(const FrontendToTorchLoweringPass &pass) |
| 46 | + : PassWrapper<FrontendToTorchLoweringPass, OperationPass<ModuleOp>>() {} |
| 47 | + |
| 48 | + void runOnOperation() final; |
| 49 | +}; |
| 50 | + |
| 51 | +void FrontendToTorchLoweringPass::runOnOperation() { |
| 52 | + ModuleOp module = getOperation(); |
| 53 | + // The first thing to define is the conversion target. This will define the |
| 54 | + // final target for this lowering. |
| 55 | + ConversionTarget target(getContext()); |
| 56 | + |
| 57 | + // We define the specific operations, or dialects, that are legal targets for |
| 58 | + // this lowering. |
| 59 | + target.addLegalDialect<torch::Torch::TorchDialect, func::FuncDialect>(); |
| 60 | + |
| 61 | + TypeConverter typeConverter; |
| 62 | + typeConverter.addConversion([](Type type) { return type; }); |
| 63 | + onnx_mlir::setupTorchTypeConversion(target, typeConverter); |
| 64 | + |
| 65 | + // Now that the conversion target has been defined, we just need to provide |
| 66 | + // the set of patterns that will lower the frontend operations. |
| 67 | + RewritePatternSet patterns(&getContext()); |
| 68 | + |
| 69 | + // Define patterns. |
| 70 | + populateONNXToTorchConversionPattern(typeConverter, patterns, &getContext()); |
| 71 | + |
| 72 | + // With the target and rewrite patterns defined, we can now attempt the |
| 73 | + // conversion. The conversion will signal failure if any of our `illegal` |
| 74 | + // operations were not converted successfully. |
| 75 | + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { |
| 76 | + signalPassFailure(); |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +std::unique_ptr<Pass> createLowerToTorchPass() { |
| 81 | + return std::make_unique<FrontendToTorchLoweringPass>(); |
| 82 | +} |
| 83 | + |
| 84 | +} // namespace onnx_mlir |
0 commit comments