Skip to content

Commit 6d2827e

Browse files
authored
flag -onnx-const-prop-disable-pattern (#2456)
* add flag -onnx-const-prop-disable-pattern * implement disabledPatterns * update lit tests * omit SplitOfConst pattern if disabled Signed-off-by: Soren Lassen <[email protected]>
1 parent 0a64272 commit 6d2827e

File tree

8 files changed

+303
-322
lines changed

8 files changed

+303
-322
lines changed

src/Compiler/CompilerOptions.cpp

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,60 @@
2424
namespace onnx_mlir {
2525

2626
// Use external storage for the options so that they are globally accessible
27-
std::string inputFilename; // common for both
28-
std::string outputBaseName; // common for both
29-
std::vector<accel::Accelerator::Kind> maccel; // common for both
30-
OptLevel OptimizationLevel; // common for both
31-
std::string mtriple; // common for both
32-
std::string mcpu; // common for both
33-
std::string march; // common for both
34-
InstrumentStages instrumentStage; // common for both
35-
int onnxConstPropExpansionBound; // common for both
36-
bool enableONNXHybridPass; // common for both
37-
std::vector<std::string> functionsToDecompose; // common for both
38-
EmissionTargetType emissionTarget; // onnx-mlir only
39-
bool invokeOnnxVersionConverter; // onnx-mlir only
40-
bool preserveLocations; // onnx-mlir only
41-
bool printIR; // onnx-mlir only
42-
bool preserveBitcode; // onnx-mlir only
43-
bool preserveLLVMIR; // onnx-mlir only
44-
bool preserveMLIR; // onnx-mlir only
45-
bool useOnnxModelTypes; // onnx-mlir only
46-
int repeatOnnxTransform; // onnx-mlir only
47-
std::string shapeInformation; // onnx-mlir only
48-
ModelSize modelSize; // onnx-mlir only
49-
bool storeConstantsToFile; // onnx-mlir only
50-
float constantsToFileTotalThreshold; // onnx-mlir only
51-
float constantsToFileSingleThreshold; // onnx-mlir only
52-
bool VerboseOutput; // onnx-mlir only
53-
std::vector<std::string> Xopt; // onnx-mlir only
54-
std::vector<std::string> Xllc; // onnx-mlir only
55-
std::string mllvm; // onnx-mlir only
56-
std::string instrumentOps; // onnx-mlir only
57-
unsigned instrumentControlBits; // onnx-mlir only
58-
bool instrumentONNXSignature; // onnx-mlir only
59-
std::string ONNXOpStats; // onnx-mlir only
60-
bool enableMemoryBundling; // onnx-mlir only
61-
int onnxOpTransformThreshold; // onnx-mlir only
62-
bool onnxOpTransformReport; // onnx-mlir only
63-
bool enableParallel; // onnx-mlir only
64-
bool disableSimdOption; // onnx-mlir only
65-
bool enableSimdDataLayout; // onnx-mlir only
66-
bool verifyInputTensors; // onnx-mlir only
67-
bool allowSorting; // onnx-mlir only
68-
std::string reportHeapBefore; // onnx-mlir only
69-
std::string reportHeapAfter; // onnx-mlir only
70-
std::string modelTag; // onnx-mlir only
71-
bool enableConvOptPass; // onnx-mlir only
72-
std::vector<std::string> extraLibPaths; // onnx-mlir only
73-
std::vector<std::string> extraLibs; // onnx-mlir only
74-
ProfileIRs profileIR; // onnx-mlir only
75-
OptReport optReport; // onnx-mlir only
76-
bool split_input_file; // onnx-mlir-opt only
77-
bool verify_diagnostics; // onnx-mlir-opt only
78-
bool verify_passes; // onnx-mlir-opt only
79-
bool allowUnregisteredDialects; // onnx-mlir-opt only
27+
std::string inputFilename; // common for both
28+
std::string outputBaseName; // common for both
29+
std::vector<accel::Accelerator::Kind> maccel; // common for both
30+
OptLevel OptimizationLevel; // common for both
31+
std::string mtriple; // common for both
32+
std::string mcpu; // common for both
33+
std::string march; // common for both
34+
InstrumentStages instrumentStage; // common for both
35+
int onnxConstPropExpansionBound; // common for both
36+
std::vector<std::string> onnxConstPropDisablePatterns; // common for both
37+
bool enableONNXHybridPass; // common for both
38+
std::vector<std::string> functionsToDecompose; // common for both
39+
EmissionTargetType emissionTarget; // onnx-mlir only
40+
bool invokeOnnxVersionConverter; // onnx-mlir only
41+
bool preserveLocations; // onnx-mlir only
42+
bool printIR; // onnx-mlir only
43+
bool preserveBitcode; // onnx-mlir only
44+
bool preserveLLVMIR; // onnx-mlir only
45+
bool preserveMLIR; // onnx-mlir only
46+
bool useOnnxModelTypes; // onnx-mlir only
47+
int repeatOnnxTransform; // onnx-mlir only
48+
std::string shapeInformation; // onnx-mlir only
49+
ModelSize modelSize; // onnx-mlir only
50+
bool storeConstantsToFile; // onnx-mlir only
51+
float constantsToFileTotalThreshold; // onnx-mlir only
52+
float constantsToFileSingleThreshold; // onnx-mlir only
53+
bool VerboseOutput; // onnx-mlir only
54+
std::vector<std::string> Xopt; // onnx-mlir only
55+
std::vector<std::string> Xllc; // onnx-mlir only
56+
std::string mllvm; // onnx-mlir only
57+
std::string instrumentOps; // onnx-mlir only
58+
unsigned instrumentControlBits; // onnx-mlir only
59+
bool instrumentONNXSignature; // onnx-mlir only
60+
std::string ONNXOpStats; // onnx-mlir only
61+
bool enableMemoryBundling; // onnx-mlir only
62+
int onnxOpTransformThreshold; // onnx-mlir only
63+
bool onnxOpTransformReport; // onnx-mlir only
64+
bool enableParallel; // onnx-mlir only
65+
bool disableSimdOption; // onnx-mlir only
66+
bool enableSimdDataLayout; // onnx-mlir only
67+
bool verifyInputTensors; // onnx-mlir only
68+
bool allowSorting; // onnx-mlir only
69+
std::string reportHeapBefore; // onnx-mlir only
70+
std::string reportHeapAfter; // onnx-mlir only
71+
std::string modelTag; // onnx-mlir only
72+
bool enableConvOptPass; // onnx-mlir only
73+
std::vector<std::string> extraLibPaths; // onnx-mlir only
74+
std::vector<std::string> extraLibs; // onnx-mlir only
75+
ProfileIRs profileIR; // onnx-mlir only
76+
OptReport optReport; // onnx-mlir only
77+
bool split_input_file; // onnx-mlir-opt only
78+
bool verify_diagnostics; // onnx-mlir-opt only
79+
bool verify_passes; // onnx-mlir-opt only
80+
bool allowUnregisteredDialects; // onnx-mlir-opt only
8081

8182
// Category for common options shared between onnx-mlir and onnx-mlir-opt.
8283
llvm::cl::OptionCategory OnnxMlirCommonOptions("common options",
@@ -163,6 +164,14 @@ static llvm::cl::opt<int, true> onnxConstPropExpansionBoundOpt(
163164
llvm::cl::location(onnxConstPropExpansionBound), llvm::cl::init(-1),
164165
llvm::cl::cat(OnnxMlirCommonOptions));
165166

167+
static llvm::cl::list<std::string, std::vector<std::string>>
168+
onnxConstPropDisablePatternsOpt("onnx-const-prop-disable-pattern",
169+
llvm::cl::desc("Named constant propagation pattern to disable. "
170+
"Repeat the flag to disable multiple patterns."),
171+
llvm::cl::value_desc("named constant propagation patterns to disable"),
172+
llvm::cl::location(onnxConstPropDisablePatterns),
173+
llvm::cl::cat(OnnxMlirCommonOptions));
174+
166175
static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
167176
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
168177
"Set to 'false' if you want to disable ONNX hybrid pass."),

src/Compiler/CompilerOptions.hpp

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -66,59 +66,60 @@ extern llvm::cl::OptionCategory OnnxMlirOptions;
6666
extern llvm::cl::OptionCategory OnnxMlirOptOptions;
6767

6868
// Options known to onnx-mlir and/or onnx-mlir-opt
69-
extern std::string inputFilename; // common for both
70-
extern std::string outputBaseName; // common for both
71-
extern std::vector<accel::Accelerator::Kind> maccel; // common for both
72-
extern OptLevel OptimizationLevel; // common for both
73-
extern std::string mtriple; // common for both
74-
extern std::string mcpu; // common for both
75-
extern std::string march; // common for both
76-
extern InstrumentStages instrumentStage; // common for both
77-
extern int onnxConstPropExpansionBound; // common for both
78-
extern bool enableONNXHybridPass; // common for both
79-
extern std::vector<std::string> functionsToDecompose; // common for both
80-
extern EmissionTargetType emissionTarget; // onnx-mlir only
81-
extern bool invokeOnnxVersionConverter; // onnx-mlir only
82-
extern bool preserveLocations; // onnx-mlir only
83-
extern bool printIR; // onnx-mlir only
84-
extern bool preserveBitcode; // onnx-mlir only
85-
extern bool preserveLLVMIR; // onnx-mlir only
86-
extern bool preserveMLIR; // onnx-mlir only
87-
extern bool useOnnxModelTypes; // onnx-mlir only
88-
extern int repeatOnnxTransform; // onnx-mlir only
89-
extern std::string shapeInformation; // onnx-mlir only
90-
extern ModelSize modelSize; // onnx-mlir only
91-
extern bool storeConstantsToFile; // onnx-mlir only
92-
extern float constantsToFileTotalThreshold; // onnx-mlir only
93-
extern float constantsToFileSingleThreshold; // onnx-mlir only
94-
extern bool VerboseOutput; // onnx-mlir only
95-
extern std::vector<std::string> Xopt; // onnx-mlir only
96-
extern std::vector<std::string> Xllc; // onnx-mlir only
97-
extern std::string mllvm; // onnx-mlir only
98-
extern std::string instrumentOps; // onnx-mlir only
99-
extern unsigned instrumentControlBits; // onnx-mlir only
100-
extern bool instrumentONNXSignature; // onnx-mlir only
101-
extern std::string ONNXOpStats; // onnx-mlir only
102-
extern bool enableMemoryBundling; // onnx-mlir only
103-
extern int onnxOpTransformThreshold; // onnx-mlir only
104-
extern bool onnxOpTransformReport; // onnx-mlir only
105-
extern bool enableParallel; // onnx-mlir only
106-
extern bool disableSimdOption; // onnx-mlir only
107-
extern bool enableSimdDataLayout; // onnx-mlir only
108-
extern bool verifyInputTensors; // onnx-mlir only
109-
extern bool allowSorting; // onnx-mlir only
110-
extern std::string reportHeapBefore; // onnx-mlir only
111-
extern std::string reportHeapAfter; // onnx-mlir only
112-
extern std::string modelTag; // onnx-mlir only
113-
extern bool enableConvOptPass; // onnx-mlir only
114-
extern std::vector<std::string> extraLibPaths; // onnx-mlir only
115-
extern std::vector<std::string> extraLibs; // onnx-mlir only
116-
extern ProfileIRs profileIR; // onnx-mlir only
117-
extern OptReport optReport; // onnx-mlir only
118-
extern bool split_input_file; // onnx-mlir-opt only
119-
extern bool verify_diagnostics; // onnx-mlir-opt only
120-
extern bool verify_passes; // onnx-mlir-opt only
121-
extern bool allowUnregisteredDialects; // onnx-mlir-opt only
69+
extern std::string inputFilename; // common for both
70+
extern std::string outputBaseName; // common for both
71+
extern std::vector<accel::Accelerator::Kind> maccel; // common for both
72+
extern OptLevel OptimizationLevel; // common for both
73+
extern std::string mtriple; // common for both
74+
extern std::string mcpu; // common for both
75+
extern std::string march; // common for both
76+
extern InstrumentStages instrumentStage; // common for both
77+
extern int onnxConstPropExpansionBound; // common for both
78+
extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
79+
extern bool enableONNXHybridPass; // common for both
80+
extern std::vector<std::string> functionsToDecompose; // common for both
81+
extern EmissionTargetType emissionTarget; // onnx-mlir only
82+
extern bool invokeOnnxVersionConverter; // onnx-mlir only
83+
extern bool preserveLocations; // onnx-mlir only
84+
extern bool printIR; // onnx-mlir only
85+
extern bool preserveBitcode; // onnx-mlir only
86+
extern bool preserveLLVMIR; // onnx-mlir only
87+
extern bool preserveMLIR; // onnx-mlir only
88+
extern bool useOnnxModelTypes; // onnx-mlir only
89+
extern int repeatOnnxTransform; // onnx-mlir only
90+
extern std::string shapeInformation; // onnx-mlir only
91+
extern ModelSize modelSize; // onnx-mlir only
92+
extern bool storeConstantsToFile; // onnx-mlir only
93+
extern float constantsToFileTotalThreshold; // onnx-mlir only
94+
extern float constantsToFileSingleThreshold; // onnx-mlir only
95+
extern bool VerboseOutput; // onnx-mlir only
96+
extern std::vector<std::string> Xopt; // onnx-mlir only
97+
extern std::vector<std::string> Xllc; // onnx-mlir only
98+
extern std::string mllvm; // onnx-mlir only
99+
extern std::string instrumentOps; // onnx-mlir only
100+
extern unsigned instrumentControlBits; // onnx-mlir only
101+
extern bool instrumentONNXSignature; // onnx-mlir only
102+
extern std::string ONNXOpStats; // onnx-mlir only
103+
extern bool enableMemoryBundling; // onnx-mlir only
104+
extern int onnxOpTransformThreshold; // onnx-mlir only
105+
extern bool onnxOpTransformReport; // onnx-mlir only
106+
extern bool enableParallel; // onnx-mlir only
107+
extern bool disableSimdOption; // onnx-mlir only
108+
extern bool enableSimdDataLayout; // onnx-mlir only
109+
extern bool verifyInputTensors; // onnx-mlir only
110+
extern bool allowSorting; // onnx-mlir only
111+
extern std::string reportHeapBefore; // onnx-mlir only
112+
extern std::string reportHeapAfter; // onnx-mlir only
113+
extern std::string modelTag; // onnx-mlir only
114+
extern bool enableConvOptPass; // onnx-mlir only
115+
extern std::vector<std::string> extraLibPaths; // onnx-mlir only
116+
extern std::vector<std::string> extraLibs; // onnx-mlir only
117+
extern ProfileIRs profileIR; // onnx-mlir only
118+
extern OptReport optReport; // onnx-mlir only
119+
extern bool split_input_file; // onnx-mlir-opt only
120+
extern bool verify_diagnostics; // onnx-mlir-opt only
121+
extern bool verify_passes; // onnx-mlir-opt only
122+
extern bool allowUnregisteredDialects; // onnx-mlir-opt only
122123

123124
extern std::string customEnvFlags;
124125

src/Compiler/CompilerPasses.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ namespace onnx_mlir {
4545
void configurePasses() {
4646
// Set global vector machine support.
4747
VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
48-
configureConstPropONNXToONNXPass(onnxConstPropExpansionBound);
48+
configureConstPropONNXToONNXPass(
49+
onnxConstPropExpansionBound, onnxConstPropDisablePatterns);
4950
configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
5051
enableParallel, optReport == OptReport::Simd, !disableSimdOption);
5152
}

src/Pass/Passes.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <memory>
1818
#include <string>
1919

20+
#include "llvm/ADT/ArrayRef.h"
21+
2022
namespace mlir {
2123
class MLIRContext;
2224
class Pass;
@@ -42,7 +44,8 @@ std::unique_ptr<mlir::Pass> createConvOptONNXToONNXPass(
4244
std::unique_ptr<mlir::Pass> createShapeInferencePass();
4345

4446
// To configure ConstPropONNXToONNXPass at program start.
45-
void configureConstPropONNXToONNXPass(int expansionBound);
47+
void configureConstPropONNXToONNXPass(
48+
int expansionBound, llvm::ArrayRef<std::string> disabledPatterns = {});
4649

4750
std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
4851

0 commit comments

Comments
 (0)