diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
index 6db563df01..2ba720859a 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -195,7 +195,8 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
   // LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
   if (emissionTarget >= EmitONNXIR) {
-    addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
+    addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(),
+        /*donotScrubDisposableElementsAttr*/ true);
     pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
         nnpaSaveDevicePlacementFile, nnpaPlacementHeuristic));
   }
diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index 6a7fae8c2f..10d8d784a9 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -54,7 +54,8 @@ void configurePasses() {
       !disableSimdOption);
 }

-void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
+void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
+    bool donotScrubDisposableElementsAttr) {
   // This is a transition from previous static passes to full dynamic passes
   // Static passes are kept and the dynamic pass is added as IF-THEN
   // with the static iteration.
@@ -132,7 +133,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
   pm.addPass(mlir::createSymbolDCEPass());

   // Replace every DisposableElementsAttr with DenseElementsAttr.
-  pm.addPass(createScrubDisposablePass());
+  if (!donotScrubDisposableElementsAttr)
+    pm.addPass(createScrubDisposablePass());

   // Set onnx_node_name if it is missing. Keep this pass at the end of this
   // function and just before instrumentation.
diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp
index f0c0499f8f..9a6987cf19 100644
--- a/src/Compiler/CompilerPasses.hpp
+++ b/src/Compiler/CompilerPasses.hpp
@@ -20,7 +20,8 @@ namespace onnx_mlir {
 // Configures passes up front based on command line options.
 void configurePasses();

-void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU);
+void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
+    bool donotScrubDisposableElementsAttr = false);
 void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
     std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
 void addKrnlToAffinePasses(mlir::PassManager &pm);
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index f92bada090..c29eabfcc4 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -202,11 +202,7 @@ bool haveSameStaticShape(Value lhs, Value rhs) {

 /// Test if the input is a splat constant with a negative value or not.
 bool isNegativeSplatConstant(Value val) {
-  if (!isDenseONNXConstant(val))
-    return false;
-  ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
-  auto valAttr =
-      llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
+  ElementsAttr valAttr = getElementAttributeFromONNXValue(val);
   if (!valAttr)
     return false;

@@ -238,9 +234,7 @@ bool areAllDimSizes(ValueRange vals) {
     Type elemTy = mlir::cast<ShapedType>(val.getType()).getElementType();
     if (!mlir::isa<IntegerType>(elemTy))
       return false;
-    ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
-    auto valAttr =
-        llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
+    ElementsAttr valAttr = getElementAttributeFromONNXValue(val);
     if (!valAttr)
       return false;
     int64_t v = (*valAttr.getValues<APInt>().begin()).getSExtValue();
@@ -752,9 +746,7 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern<ONNXLoopOp> {
     Operation *definingOp = v.getDefiningOp();
     if (mlir::isa<IntegerType>(
             mlir::cast<ShapedType>(v.getType()).getElementType()) &&
-        isa<ONNXConstantOp>(definingOp) &&
-        mlir::isa<DenseElementsAttr>(
-            mlir::cast<ONNXConstantOp>(definingOp).getValueAttr()))
+        isDenseONNXConstant(v))
       return true;
     return false;
   }
@@ -795,10 +787,8 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern<ONNXLoopOp> {

   // A helper function to get an integer constant from a value.
   int64_t getOneIntegerConstant(Value v) const {
-    Operation *definingOp = v.getDefiningOp();
-    DenseElementsAttr valueAttr = mlir::cast<DenseElementsAttr>(
-        mlir::cast<ONNXConstantOp>(definingOp).getValueAttr());
-    return (*valueAttr.getValues<APInt>().begin()).getSExtValue();
+    return onnx_mlir::getScalarValue<int64_t>(
+        v.getDefiningOp<ONNXConstantOp>());
   }

   // A helper function to match the pattern of the given operation. It also
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
index 0625bb0274..6b5ed293de 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -81,11 +81,6 @@ def createArrayAttrOfTwoToRankOf : NativeCodeCall<

 def AttributeIsNotNull :
     Constraint<CPred<" ($_self) ">, "Attribute is not null">;

-def IsDenseElementsAttr :
-    Constraint<And<[CPred<" ($_self) ">,
-              CPred<"mlir::isa<DenseElementsAttr>(($_self))">
-              ]>, "Attribute is not a DenseElementsAttr">;
-
 // Intended to check whether there is at least one not-Null the attributes
 // However, the current table gen can only support max 4 parameters
 // Multiple rules are used instead of one rule
@@ -197,18 +192,13 @@ def HasSpecifiedConstantShape: Constraint<
     "Has the specified constant shape">;

 def IsFromONNXConstantOp: Constraint<
-    CPred<"llvm::dyn_cast_or_null<ONNXConstantOp>($0.getDefiningOp())">,
+    CPred<"onnx_mlir::isDenseONNXConstant($0)">,
     "Is a value from ONNXConstantOp">;

 def IsNotFromONNXConstantOp: Constraint<
     CPred<"!(llvm::dyn_cast_or_null<ONNXConstantOp>($0.getDefiningOp()))">,
     "Is a value not from ONNXConstantOp">;

-def IsFromONNXConstantOpWithDenseElementsAttr: Constraint<
-    And<[CPred<" $_self.getDefiningOp<ONNXConstantOp>() ">,
-       CPred<" isa<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self).getValueAttr()) ">
-    ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">;
-
 def IsNegativeSplatConstant: Constraint<
   CPred<"onnx_mlir::isNegativeSplatConstant($_self)">,
   "Is a splat constant with a negative value."
@@ -407,8 +397,8 @@ def FuseMulConvNullBiasPattern: Pat<
       // unchanged operands and attributes.
       $b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
   [(HasNoneType $b),
-   (IsDenseElementsAttr:$denseAttr),
-   (IsFromONNXConstantOpWithDenseElementsAttr:$w),
+   (AttributeIsNotNull:$denseAttr),
+   (IsFromONNXConstantOp $w),
    (HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type.
    (HasRankGT<1> $w), // rank of $w must be at least 2.
   (RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1.
@@ -474,13 +464,13 @@ def SwapCastSlicePattern: Pat<
 // Canonicalization for ONNXTileOp
 //===----------------------------------------------------------------------===//

-def IsFromONNXConstantOpWithOnesDenseElementsAttr: Constraint<
-  And<[IsFromONNXConstantOpWithDenseElementsAttr.predicate,
+def IsFromONNXConstantOpWithOnes: Constraint<
+  And<[CPred<"onnx_mlir::isDenseONNXConstant($_self)">,
     CPred<"::llvm::all_of("
-      "mlir::dyn_cast<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self)"
-      ".getValueAttr()).getValues<int64_t>(), "
+      "onnx_mlir::getElementAttributeFromONNXValue($_self)"
+      ".getValues<int64_t>(), "
       "[](int64_t repeat) { return repeat == 1;})">
-  ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr of ones">;
+  ]>, "Value is not a ONNXConstantOp with an ElementsAttr of ones">;

 def RemoveIdentityTilePattern: Pat<
   // Tile with `repeats` of all constant 1's
   (ONNXTileOp:$result $val, (ONNXConstantOp)),
@@ -488,7 +478,7 @@ def RemoveIdentityTilePattern: Pat<
   // Remove the tile.
   (replaceWithValue $val),
   // Check that we have indeed a identity tile pattern.
-  [(IsFromONNXConstantOpWithOnesDenseElementsAttr:$r), (HaveSameShapedType $val,$result)]>;
+  [(IsFromONNXConstantOpWithOnes:$r), (HaveSameShapedType $val,$result)]>;

 //===----------------------------------------------------------------------===//
 // Canonicalization for ONNXLayoutTransformOp