Skip to content

Commit 0a64272

Browse files
sorenlassenphilass
andauthored
associative const prop rewrite only 1-use subexpressions (#2461)
* associative const prop rewrite only 1-use subexpressions Signed-off-by: Soren Lassen <[email protected]> * clean up whitespace Signed-off-by: Soren Lassen <[email protected]> --------- Signed-off-by: Soren Lassen <[email protected]> Co-authored-by: Philip Lassen <[email protected]>
1 parent 9694cc4 commit 0a64272

File tree

2 files changed

+49
-28
lines changed

2 files changed

+49
-28
lines changed

src/Transform/ONNX/ConstProp.td

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ include "src/Dialect/ONNX/ONNX.td"
4141

4242
// Useful test definitions:
4343

44+
def HasOneUse : Constraint<CPred<"$_self.hasOneUse()">, "op has exactly one use">;
45+
4446
def IsNoneType : Constraint<CPred<"isa<NoneType>(($_self).getType())">>;
4547

4648
def IsIntOrFloatType : Constraint<CPred<"isa<IntegerType, FloatType>(($_self).getType().cast<ShapedType>().getElementType())">>;
@@ -219,7 +221,7 @@ def CreateNonZeroOfConst :
219221
//===----------------------------------------------------------------------===//
220222
// Patterns to enable opportunities with elementwise ADD operations.
221223
//===----------------------------------------------------------------------===//
222-
224+
223225
// Use commutativity to normalize constants in the second position of Add.
224226
def AddConstCommutative1 : Pat<
225227
// From add(c, x).
@@ -228,52 +230,52 @@ def AddConstCommutative1 : Pat<
228230
(ONNXAddOp $x, $c),
229231
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
230232
[(IsNotAConstant:$x)]>;
231-
233+
232234
// Use associativity to add constants together.
233235
def AddConstAssociative1 : Pat<
234236
// From add(add(x, c1), c2).
235237
(ONNXAddOp
236-
(ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
238+
(ONNXAddOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
237239
(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_)),
238240
// To add(x, add(c1, c2)).
239241
(ONNXAddOp
240242
$x,
241243
(ONNXAddOp $c1, $c2)),
242-
[(IsNotAConstant:$x)]>;
244+
[(IsNotAConstant:$x), (HasOneUse:$lhs)]>;
243245

244246
def AddConstAssociative2 : Pat<
245247
// From add(add(x, c), y).
246248
(ONNXAddOp
247-
(ONNXAddOp $x,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
249+
(ONNXAddOp:$lhs $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
248250
$y),
249251
// To add(add(x, y), c).
250252
(ONNXAddOp
251253
(ONNXAddOp $x, $y),
252254
$c),
253-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
255+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs)]>;
254256

255257
def AddConstAssociative3 : Pat<
256258
// From add(x, add(y, c)).
257259
(ONNXAddOp
258260
$x,
259-
(ONNXAddOp $y,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
261+
(ONNXAddOp:$rhs $y, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
260262
// To add(add(x, y), c).
261263
(ONNXAddOp
262264
(ONNXAddOp $x, $y),
263265
$c),
264-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
266+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs)]>;
265267

266268
def AddConstAssociative4 : Pat<
267269
// From add(add(x, c1), add(y, c2)).
268270
(ONNXAddOp
269-
(ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
270-
(ONNXAddOp $y,(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
271+
(ONNXAddOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
272+
(ONNXAddOp:$rhs $y, (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
271273
// To add(add(x, y), c1+c2).
272274
(ONNXAddOp
273275
(ONNXAddOp $x, $y),
274276
(ONNXAddOp $c1, $c2)),
275-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
276-
277+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), (HasOneUse:$rhs)]>;
278+
277279
// Constant Propagation for Add
278280
def AddConstProp : Pat<
279281
// From add(c1, c2).
@@ -336,7 +338,7 @@ def NegofConst : Pat<
336338
// To (-c)
337339
(CreateNegOfConst $negOp, $input),
338340
[(IsFromDenseONNXConstantOp:$input)]>;
339-
341+
340342
// Change a subtraction of a constant c by an addition of -c. Helpfull to combine
341343
// with other add optimizations.
342344
def SubConstToNeg : Pat<
@@ -361,7 +363,7 @@ def ReluofConst : Pat<
361363
// To relu(c)
362364
(CreateReluOfConst $reluOp, $input),
363365
[(IsFromDenseONNXConstantOp:$input)]>;
364-
366+
365367
//===----------------------------------------------------------------------===//
366368
// Const propagation patterns for variadic elementwise operations.
367369
//===----------------------------------------------------------------------===//
@@ -409,51 +411,51 @@ def MulConstCommutative1 : Pat<
409411
(ONNXMulOp $x, $c),
410412
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
411413
[(IsNotAConstant:$x)]>;
412-
414+
413415
// Use associativity to mul constants together.
414416
def MulConstAssociative1 : Pat<
415417
// From mul(mul(x, c1), c2).
416418
(ONNXMulOp
417-
(ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
419+
(ONNXMulOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
418420
(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_)),
419421
// To mul(x, mul(c1, c2)).
420422
(ONNXMulOp
421423
$x,
422424
(ONNXMulOp $c1, $c2)),
423-
[(IsNotAConstant:$x)]>;
424-
425+
[(IsNotAConstant:$x), (HasOneUse:$lhs)]>;
426+
425427
def MulConstAssociative2 : Pat<
426428
// From mul(mul(x, c), y).
427429
(ONNXMulOp
428-
(ONNXMulOp $x,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
430+
(ONNXMulOp:$lhs $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
429431
$y),
430432
// To mul(mul(x, y), c).
431433
(ONNXMulOp
432434
(ONNXMulOp $x, $y),
433435
$c),
434-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
436+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs)]>;
435437

436438
def MulConstAssociative3 : Pat<
437439
// From mul(x, mul(y, c)).
438440
(ONNXMulOp
439441
$x,
440-
(ONNXMulOp $y,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
442+
(ONNXMulOp:$rhs $y, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
441443
// To mul(mul(x, y), c).
442444
(ONNXMulOp
443445
(ONNXMulOp $x, $y),
444446
$c),
445-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
447+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs)]>;
446448

447449
def MulConstAssociative4 : Pat<
448450
// From mul(mul(x, c1), mul(y, c2)).
449451
(ONNXMulOp
450-
(ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
451-
(ONNXMulOp $y,(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
452+
(ONNXMulOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
453+
(ONNXMulOp:$rhs $y, (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
452454
// To mul(mul(x, y), c1+c2).
453455
(ONNXMulOp
454456
(ONNXMulOp $x, $y),
455457
(ONNXMulOp $c1, $c2)),
456-
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
458+
[(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), (HasOneUse:$rhs)]>;
457459

458460
// Constant Propagation for Mul
459461
def MulConstProp : Pat<
@@ -480,7 +482,7 @@ def MulOnesOnRhs : Pat<
480482
(ValuesHaveSameType $result, $x)
481483
]>;
482484

483-
// Constant Propagation for Div
485+
// Constant Propagation for Div
484486
def DivConstProp : Pat<
485487
// From div(c1, c2).
486488
(ONNXDivOp:$divOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),

test/mlir/onnx/onnx_constprop.mlir

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ func.func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
103103
%1 = onnx.Constant dense<[10, 11, 12]> : tensor<3xi32>
104104
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
105105
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
106-
%4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
107-
"onnx.Return"(%4) : (tensor<3xi32>) -> ()
106+
%4 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
107+
%5 = "onnx.Add"(%3, %4) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
108+
"onnx.Return"(%5) : (tensor<3xi32>) -> ()
108109
// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32>
109110
// CHECK-DAG: [[CONST1:%.+]] = onnx.Constant dense<[10, 13, 16]> : tensor<3xi32>
110111
// CHECK-DAG: [[ADD1:%.+]] = "onnx.Add"(%arg0, %arg0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
@@ -132,6 +133,24 @@ func.func @test_add_constant_5(%arg0 : tensor<3xi32>, %arg1: tensor<3xi32>, %arg
132133

133134
// -----
134135

136+
func.func @test_add_const_associative2_2uses(%x: tensor<i32>, %y: tensor<i32>, %z: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
137+
%c = onnx.Constant dense<1> : tensor<i32>
138+
%1 = "onnx.Add"(%x, %c) : (tensor<i32> , tensor<i32>) -> tensor<i32>
139+
%2 = "onnx.Add"(%1, %y) : (tensor<i32> , tensor<i32>) -> tensor<i32>
140+
%3 = "onnx.Add"(%1, %z) : (tensor<i32> , tensor<i32>) -> tensor<i32>
141+
onnx.Return %2, %3 : tensor<i32>, tensor<i32>
142+
// CHECK-LABEL: func.func @test_add_const_associative2_2uses
143+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<i32>, [[PARAM_1_:%.+]]: tensor<i32>, [[PARAM_2_:%.+]]: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
144+
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<i32>
145+
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
146+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_1_]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
147+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_2_]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
148+
// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]] : tensor<i32>, tensor<i32>
149+
// CHECK: }
150+
}
151+
152+
// -----
153+
135154
// CHECK-LABEL: @test_add_zeros(%arg0: tensor<3xi32>) -> tensor<3xi32>
136155
func.func @test_add_zeros(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
137156
%0 = onnx.Constant dense<[0, 0, 0]> : tensor<3xi32>

0 commit comments

Comments
 (0)