Skip to content

Commit

Permalink
Delay scrubbing disposable elements attrs as long as possible (#3006)
Browse files Browse the repository at this point in the history
* Delay scrubbing disposable elements attribute

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Nov 13, 2024
1 parent cb9a949 commit fa91033
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 38 deletions.
3 changes: 2 additions & 1 deletion src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

// LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
if (emissionTarget >= EmitONNXIR) {
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(),
/*donotScrubDisposableElementsAttr*/ true);
pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
nnpaSaveDevicePlacementFile, nnpaPlacementHeuristic));
}
Expand Down
6 changes: 4 additions & 2 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ void configurePasses() {
!disableSimdOption);
}

void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr) {
// This is a transition from previous static passes to full dynamic passes
// Static passes are kept and the dynamic pass is added as IF-THEN
// with the static iteration.
Expand Down Expand Up @@ -132,7 +133,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
pm.addPass(mlir::createSymbolDCEPass());

// Replace every DisposableElementsAttr with DenseElementsAttr.
pm.addPass(createScrubDisposablePass());
if (!donotScrubDisposableElementsAttr)
pm.addPass(createScrubDisposablePass());

// Set onnx_node_name if it is missing. Keep this pass at the end of this
// function and just before instrumentation.
Expand Down
3 changes: 2 additions & 1 deletion src/Compiler/CompilerPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace onnx_mlir {
// Configures passes up front based on command line options.
void configurePasses();

void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU);
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr = false);
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
void addKrnlToAffinePasses(mlir::PassManager &pm);
Expand Down
20 changes: 5 additions & 15 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,7 @@ bool haveSameStaticShape(Value lhs, Value rhs) {

/// Test if the input is a splat constant with a negative value or not.
bool isNegativeSplatConstant(Value val) {
if (!isDenseONNXConstant(val))
return false;
ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
auto valAttr =
llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
ElementsAttr valAttr = getElementAttributeFromONNXValue(val);
if (!valAttr)
return false;

Expand Down Expand Up @@ -238,9 +234,7 @@ bool areAllDimSizes(ValueRange vals) {
Type elemTy = mlir::cast<ShapedType>(val.getType()).getElementType();
if (!mlir::isa<IntegerType>(elemTy))
return false;
ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
auto valAttr =
llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
ElementsAttr valAttr = getElementAttributeFromONNXValue(val);
if (!valAttr)
return false;
int64_t v = (*valAttr.getValues<APInt>().begin()).getSExtValue();
Expand Down Expand Up @@ -752,9 +746,7 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern<ONNXLoopOp> {
Operation *definingOp = v.getDefiningOp();
if (mlir::isa<IntegerType>(
mlir::cast<ShapedType>(v.getType()).getElementType()) &&
isa<ONNXConstantOp>(definingOp) &&
mlir::isa<DenseElementsAttr>(
mlir::cast<ONNXConstantOp>(definingOp).getValueAttr()))
isDenseONNXConstant(v))
return true;
return false;
}
Expand Down Expand Up @@ -795,10 +787,8 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern<ONNXLoopOp> {

// A helper function to get an integer constant from a value.
int64_t getOneIntegerConstant(Value v) const {
Operation *definingOp = v.getDefiningOp();
DenseElementsAttr valueAttr = mlir::cast<DenseElementsAttr>(
mlir::cast<ONNXConstantOp>(definingOp).getValueAttr());
return (*valueAttr.getValues<APInt>().begin()).getSExtValue();
return onnx_mlir::getScalarValue<int64_t>(
v.getDefiningOp<ONNXConstantOp>());
}

// A helper function to match the pattern of the given operation. It also
Expand Down
28 changes: 9 additions & 19 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ def createArrayAttrOfTwoToRankOf : NativeCodeCall<
def AttributeIsNotNull :
Constraint<CPred<"($_self)">, "Attribute is not null">;

def IsDenseElementsAttr :
Constraint<And<[CPred<" ($_self) ">,
CPred<"mlir::isa<DenseElementsAttr>(($_self))">
]>, "Attribute is not a DenseElementsAttr">;

// Intended to check whether there is at least one not-Null the attributes
// However, the current table gen can only support max 4 parameters
// Multiple rules are used instead of one rule
Expand Down Expand Up @@ -197,18 +192,13 @@ def HasSpecifiedConstantShape: Constraint<
"Has the specified constant shape">;

def IsFromONNXConstantOp: Constraint<
CPred<"llvm::dyn_cast_or_null<ONNXConstantOp>($0.getDefiningOp())">,
CPred<"onnx_mlir::isDenseONNXConstant($0)">,
"Is a value from ONNXConstantOp">;

def IsNotFromONNXConstantOp: Constraint<
CPred<"!(llvm::dyn_cast_or_null<ONNXConstantOp>($0.getDefiningOp()))">,
"Is a value not from ONNXConstantOp">;

def IsFromONNXConstantOpWithDenseElementsAttr: Constraint<
And<[CPred<" $_self.getDefiningOp<ONNXConstantOp>() ">,
CPred<" isa<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self).getValueAttr()) ">
]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">;

def IsNegativeSplatConstant: Constraint<
CPred<"onnx_mlir::isNegativeSplatConstant($_self)">,
"Is a splat constant with a negative value."
Expand Down Expand Up @@ -407,8 +397,8 @@ def FuseMulConvNullBiasPattern: Pat<
// unchanged operands and attributes.
$b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
[(HasNoneType $b),
(IsDenseElementsAttr:$denseAttr),
(IsFromONNXConstantOpWithDenseElementsAttr:$w),
(AttributeIsNotNull:$denseAttr),
(IsFromONNXConstantOp $w),
(HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type.
(HasRankGT<1> $w), // rank of $w must be at least 2.
(RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1.
Expand Down Expand Up @@ -474,21 +464,21 @@ def SwapCastSlicePattern: Pat<
// Canonicalization for ONNXTileOp
//===----------------------------------------------------------------------===//

def IsFromONNXConstantOpWithOnesDenseElementsAttr: Constraint<
And<[IsFromONNXConstantOpWithDenseElementsAttr.predicate,
def IsFromONNXConstantOpWithOnes: Constraint<
And<[CPred<"onnx_mlir::isDenseONNXConstant($_self)">,
CPred<"::llvm::all_of("
"mlir::dyn_cast<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self)"
".getValueAttr()).getValues<int64_t>(), "
"onnx_mlir::getElementAttributeFromONNXValue($_self)"
".getValues<int64_t>(), "
"[](int64_t repeat) { return repeat == 1;})">
]>, "Value is not a ONNXConstantOp with a DenseElementsAttr of ones">;
]>, "Value is not a ONNXConstantOp with an ElementsAttr of ones">;

def RemoveIdentityTilePattern: Pat<
// Tile with `repeats` of all constant 1's
(ONNXTileOp:$result $val, $r),
// Remove the tile.
(replaceWithValue $val),
// Check that we have indeed a identity tile pattern.
[(IsFromONNXConstantOpWithOnesDenseElementsAttr:$r), (HaveSameShapedType $val,$result)]>;
[(IsFromONNXConstantOpWithOnes:$r), (HaveSameShapedType $val,$result)]>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXLayoutTransformOp
Expand Down

0 comments on commit fa91033

Please sign in to comment.