[Experiment] remove flag to disable hybrid pass #2576

Draft: wants to merge 1 commit into base: main
10 changes: 1 addition & 9 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -111,15 +111,7 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
       onnx_mlir::zhigh::createZHighConstPropagationPass());
   // One more call to ONNX shape inference/canonicalization/... to update shape
   // if possible.
-  if (enableONNXHybridPass) {
-    // For starters only illustrating the new hybrid pass by replacing 3 passes
-    // here. The plan is to replace most of the passes in addONNXToMLIRPasses.
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
-  } else {
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-    pm.addPass(mlir::createCanonicalizerPass());
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-  }
+  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
   // Remove common sub-expressions.
   pm.addPass(mlir::createCSEPass());
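
For readability, here is a minimal sketch of how the affected region of addONNXToZHighPasses reads once the flag is gone. It assumes the includes and option globals already present in NNPACompilerUtils.cpp; the helper function name below is illustrative only and not part of the PR.

// Sketch only (not part of the diff): the post-change pass sequence.
static void addZHighOptimizationTail(mlir::PassManager &pm) { // illustrative name
  pm.addNestedPass<func::FuncOp>(
      onnx_mlir::zhigh::createZHighConstPropagationPass());
  // The hybrid transform now always runs here; previously this spot held either
  // the hybrid pass or a shape-inference / canonicalize / shape-inference
  // triple, selected by the removed --onnx-hybrid-pass flag.
  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
  // Remove common sub-expressions.
  pm.addPass(mlir::createCSEPass());
}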
25 changes: 0 additions & 25 deletions src/Compiler/CompilerOptions.cpp
@@ -34,7 +34,6 @@ std::string march; // common for both
 InstrumentStages instrumentStage; // common for both
 int onnxConstPropExpansionBound; // common for both
 std::vector<std::string> onnxConstPropDisablePatterns; // common for both
-bool enableONNXHybridPass; // common for both
 std::vector<std::string> functionsToDecompose; // common for both
 std::string opsForCall; // common for both
 EmissionTargetType emissionTarget; // onnx-mlir only
@@ -59,8 +58,6 @@ std::string instrumentOps; // onnx-mlir only
 unsigned instrumentControlBits; // onnx-mlir only
 bool instrumentONNXSignature; // onnx-mlir only
 std::string ONNXOpStats; // onnx-mlir only
-int onnxOpTransformThreshold; // onnx-mlir only
-bool onnxOpTransformReport; // onnx-mlir only
 bool enableParallel; // onnx-mlir only
 bool disableSimdOption; // onnx-mlir only
 bool enableSimdDataLayout; // onnx-mlir only
@@ -173,12 +170,6 @@ static llvm::cl::list<std::string, std::vector<std::string>>
         llvm::cl::location(onnxConstPropDisablePatterns),
         llvm::cl::cat(OnnxMlirCommonOptions));
 
-static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
-    llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
-                   "Set to 'false' if you want to disable ONNX hybrid pass."),
-    llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
-    llvm::cl::cat(OnnxMlirCommonOptions));
-
 static llvm::cl::list<std::string, std::vector<std::string>>
     functionsToDecomposeOpt("functions-to-decompose",
         llvm::cl::desc("Specify ONNX functions to decompose"),
@@ -377,22 +368,6 @@ static llvm::cl::opt<std::string, true> ONNXOpStatsOpt("onnx-op-stats",
     llvm::cl::location(ONNXOpStats), llvm::cl::init(""),
     llvm::cl::cat(OnnxMlirOptions));
 
-static llvm::cl::opt<int, true> onnxOpTransformThresholdOpt(
-    "onnx-op-transform-threshold",
-    llvm::cl::desc(
-        "Max iteration for dynamic op transform passes (default=3).\n"
-        "If set to 0, onnxOpTransformPass will be disabled, and\n"
-        "static iteration will be used"),
-    llvm::cl::location(onnxOpTransformThreshold), llvm::cl::init(3),
-    llvm::cl::cat(OnnxMlirOptions));
-
-static llvm::cl::opt<bool, true> onnxOpTransformReportOpt(
-    "onnx-op-transform-report",
-    llvm::cl::desc(
-        "Report diagnostic info for ONNX op transform/optimization passes."),
-    llvm::cl::location(onnxOpTransformReport), llvm::cl::init(false),
-    llvm::cl::cat(OnnxMlirOptions));
-
 static llvm::cl::opt<bool, true> enableParallelOpt("parallel",
     llvm::cl::desc("Enable parallelization (default=false)\n"
                    "Set to 'true' if you want to enable parallelization."),
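
The options deleted above follow LLVM's external-storage cl::opt pattern: a plain global holds the value, and a cl::opt<..., true> object binds the command-line flag to it through llvm::cl::location. That is why each removal pairs a variable near the top of the file with a matching ...Opt declaration further down. A generic sketch of that pattern, using a hypothetical flag name rather than onnx-mlir code:

// Generic sketch of the external-storage flag pattern used by the removed
// options; "my-experimental-flag" and myExperimentalFlag are hypothetical names.
#include "llvm/Support/CommandLine.h"

bool myExperimentalFlag; // global storage, as enableONNXHybridPass was

static llvm::cl::opt<bool, /*ExternalStorage=*/true> myExperimentalFlagOpt(
    "my-experimental-flag",
    llvm::cl::desc("Enable the experimental behavior (default=true)."),
    llvm::cl::location(myExperimentalFlag), llvm::cl::init(true));

Removing such a flag therefore means deleting both pieces and dropping every if (flag) guard in the pass pipelines, which is exactly the shape of this PR for enableONNXHybridPass, onnxOpTransformThreshold, and onnxOpTransformReport.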
3 changes: 0 additions & 3 deletions src/Compiler/CompilerOptions.hpp
@@ -76,7 +76,6 @@ extern std::string march; // common for both
 extern InstrumentStages instrumentStage; // common for both
 extern int onnxConstPropExpansionBound; // common for both
 extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
-extern bool enableONNXHybridPass; // common for both
 extern std::vector<std::string> functionsToDecompose; // common for both
 extern std::string opsForCall; // common for both
 extern EmissionTargetType emissionTarget; // onnx-mlir only
@@ -101,8 +100,6 @@ extern std::string instrumentOps; // onnx-mlir only
 extern unsigned instrumentControlBits; // onnx-mlir only
 extern bool instrumentONNXSignature; // onnx-mlir only
 extern std::string ONNXOpStats; // onnx-mlir only
-extern int onnxOpTransformThreshold; // onnx-mlir only
-extern bool onnxOpTransformReport; // onnx-mlir only
 extern bool enableParallel; // onnx-mlir only
 extern bool disableSimdOption; // onnx-mlir only
 extern bool enableSimdDataLayout; // onnx-mlir only
46 changes: 6 additions & 40 deletions src/Compiler/CompilerPasses.cpp
@@ -68,54 +68,20 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
 
   // Decompose first. Eliminates some unsupported ops without shape inference.
   pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass());
-  if (enableONNXHybridPass) {
+  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
+  // Convolution Optimization for CPU: enable when there are no accelerators.
+  if (targetCPU && enableConvOptPass) {
+    pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
+        enableSimdDataLayout && !disableSimdOption));
     pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
-    // Convolution Optimization for CPU: enable when there are no accelerators.
-    if (targetCPU && enableConvOptPass) {
-      pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
-          enableSimdDataLayout && !disableSimdOption));
-      pm.addNestedPass<func::FuncOp>(
-          onnx_mlir::createONNXHybridTransformPass());
-    }
-  } else {
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-    pm.addPass(mlir::createCanonicalizerPass());
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-    // Convolution Optimization for CPU: enable when there are no accelerators.
-    if (targetCPU && enableConvOptPass) {
-      pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
-          enableSimdDataLayout && !disableSimdOption));
-      pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-    }
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
-    if (onnxOpTransformThreshold > 0) {
-      // Dynamic iterate in ONNXOpTransformPass
-      pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
-          onnxOpTransformReport, targetCPU,
-          enableSimdDataLayout && !disableSimdOption, enableConvOptPass));
-    } else {
-      // Statically add extra passes
-      for (int i = 0; i < repeatOnnxTransform; i++) {
-        pm.addPass(mlir::createCanonicalizerPass());
-        pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-        pm.addNestedPass<func::FuncOp>(
-            onnx_mlir::createConstPropONNXToONNXPass());
-      }
-    }
   }
 
   // Simplify shape-related ops.
   pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
 
   // One more call to ONNX shape inference/canonicalization/... to update
   // shape if possible.
-  if (enableONNXHybridPass) {
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
-  } else {
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-    pm.addPass(mlir::createCanonicalizerPass());
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
-  }
+  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
 
   // Replace ONNXReturnOp with func::ReturnOp.
   pm.addPass(onnx_mlir::createStandardFuncReturnPass());
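
To make the net effect of the hunk above easier to read, here is a minimal sketch of the ONNX-level portion of addONNXToMLIRPasses after this change. It assumes the includes and option globals from CompilerPasses.cpp (enableConvOptPass, enableSimdDataLayout, disableSimdOption); pulling the sequence into a standalone helper, and the helper's name, are purely illustrative.

// Sketch only (not part of the diff): the ONNX-level pipeline after this PR.
static void addSimplifiedONNXPasses(mlir::PassManager &pm, bool targetCPU) {
  // Decompose first. Eliminates some unsupported ops without shape inference.
  pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass());
  // The hybrid transform is no longer gated by --onnx-hybrid-pass.
  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
  // Convolution optimization for CPU: enable when there are no accelerators.
  if (targetCPU && enableConvOptPass) {
    pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
        enableSimdDataLayout && !disableSimdOption));
    pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
  }
  // Simplify shape-related ops, then one more hybrid round to update shapes.
  pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
  pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
  // Replace ONNXReturnOp with func::ReturnOp.
  pm.addPass(onnx_mlir::createStandardFuncReturnPass());
}

Note that the separate constant-propagation pass and the dynamically iterated ONNXOpTransformPass from the old non-hybrid branch do not reappear anywhere; the experiment presumably relies on the hybrid transform's combined rewrites to cover that work.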
6 changes: 0 additions & 6 deletions src/Pass/Passes.hpp
@@ -29,12 +29,6 @@ namespace onnx_mlir {
 /// Pass for removing DisposableElementsAttr attributes.
 std::unique_ptr<mlir::Pass> createScrubDisposablePass(bool closeAfter = true);
 
-/// Pass for ONNX graph level optimization
-std::unique_ptr<mlir::Pass> createONNXOpTransformPass();
-std::unique_ptr<mlir::Pass> createONNXOpTransformPass(int threshold,
-    bool report, bool targetCPU, bool enableSimdDataLayoutOpt,
-    bool enableConvOptPass);
-
 /// Pass for rewriting inside frontend dialect.
 std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
     const std::string &target = "");
4 changes: 0 additions & 4 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -39,10 +39,6 @@ void registerOMPasses(int optLevel) {
     return createScrubDisposablePass();
   });
 
-  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
-    return createONNXOpTransformPass();
-  });
-
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createDecomposeONNXToONNXPass();
   });
11 changes: 0 additions & 11 deletions src/Transform/ONNX/CMakeLists.txt
@@ -55,17 +55,6 @@ add_onnx_mlir_library(OMInstrumentONNX
   MLIRPass
   )
 
-add_onnx_mlir_library(OMOpTransform
-  ONNXOpTransformPass.cpp
-
-  LINK_LIBS PUBLIC
-  OMONNXOps
-  MLIRPass
-  OMONNXRewrite
-  OMShapeInferencePass
-  MLIRTransforms
-  )
-
 add_onnx_mlir_library(OMHybridTransform
   ONNXHybridTransformPass.cpp
 
127 changes: 0 additions & 127 deletions src/Transform/ONNX/ONNXOpTransformPass.cpp

This file was deleted.
