@@ -41,6 +41,8 @@ include "src/Dialect/ONNX/ONNX.td"
41
41
42
42
// Useful test definitions:
43
43
44
+ def HasOneUse : Constraint<CPred<"$_self.hasOneUse()">, "op has exactly one use">;
45
+
44
46
def IsNoneType : Constraint<CPred<"isa<NoneType>(($_self).getType())">>;
45
47
46
48
def IsIntOrFloatType : Constraint<CPred<"isa<IntegerType, FloatType>(($_self).getType().cast<ShapedType>().getElementType())">>;
@@ -219,7 +221,7 @@ def CreateNonZeroOfConst :
219
221
//===----------------------------------------------------------------------===//
220
222
// Patterns to enable opportunities with elementwise ADD operations.
221
223
//===----------------------------------------------------------------------===//
222
-
224
+
223
225
// Use commutativity to normalize constants in the second position of Add.
224
226
def AddConstCommutative1 : Pat<
225
227
// From add(c, x).
@@ -228,52 +230,52 @@ def AddConstCommutative1 : Pat<
228
230
(ONNXAddOp $x, $c),
229
231
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
230
232
[(IsNotAConstant:$x)]>;
231
-
233
+
232
234
// Use associativity to add constants together.
233
235
def AddConstAssociative1 : Pat<
234
236
// From add(add(x, c1), c2).
235
237
(ONNXAddOp
236
- (ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
238
+ (ONNXAddOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
237
239
(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_)),
238
240
// To add(x, add(c1, c2)).
239
241
(ONNXAddOp
240
242
$x,
241
243
(ONNXAddOp $c1, $c2)),
242
- [(IsNotAConstant:$x)]>;
244
+ [(IsNotAConstant:$x), (HasOneUse:$lhs )]>;
243
245
244
246
def AddConstAssociative2 : Pat<
245
247
// From add(add(x, c), y).
246
248
(ONNXAddOp
247
- (ONNXAddOp $x,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
249
+ (ONNXAddOp:$lhs $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
248
250
$y),
249
251
// To add(add(x, y), c).
250
252
(ONNXAddOp
251
253
(ONNXAddOp $x, $y),
252
254
$c),
253
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
255
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs )]>;
254
256
255
257
def AddConstAssociative3 : Pat<
256
258
// From add(x, add(y, c)).
257
259
(ONNXAddOp
258
260
$x,
259
- (ONNXAddOp $y,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
261
+ (ONNXAddOp:$rhs $y, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
260
262
// To add(add(x, y), c).
261
263
(ONNXAddOp
262
264
(ONNXAddOp $x, $y),
263
265
$c),
264
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
266
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs )]>;
265
267
266
268
def AddConstAssociative4 : Pat<
267
269
// From add(add(x, c1), add(y, c2)).
268
270
(ONNXAddOp
269
- (ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
270
- (ONNXAddOp $y,(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
271
+ (ONNXAddOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
272
+ (ONNXAddOp:$rhs $y, (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
271
273
// To add(add(x, y), c1+c2).
272
274
(ONNXAddOp
273
275
(ONNXAddOp $x, $y),
274
276
(ONNXAddOp $c1, $c2)),
275
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
276
-
277
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), (HasOneUse:$rhs )]>;
278
+
277
279
// Constant Propagation for Add
278
280
def AddConstProp : Pat<
279
281
// From add(c1, c2).
@@ -336,7 +338,7 @@ def NegofConst : Pat<
336
338
// To (-c)
337
339
(CreateNegOfConst $negOp, $input),
338
340
[(IsFromDenseONNXConstantOp:$input)]>;
339
-
341
+
340
342
// Change a subtraction of a constant c by an addition of -c. Helpfull to combine
341
343
// with other add optimizations.
342
344
def SubConstToNeg : Pat<
@@ -361,7 +363,7 @@ def ReluofConst : Pat<
361
363
// To relu(c)
362
364
(CreateReluOfConst $reluOp, $input),
363
365
[(IsFromDenseONNXConstantOp:$input)]>;
364
-
366
+
365
367
//===----------------------------------------------------------------------===//
366
368
// Const propagation patterns for variadic elementwise operations.
367
369
//===----------------------------------------------------------------------===//
@@ -409,51 +411,51 @@ def MulConstCommutative1 : Pat<
409
411
(ONNXMulOp $x, $c),
410
412
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
411
413
[(IsNotAConstant:$x)]>;
412
-
414
+
413
415
// Use associativity to mul constants together.
414
416
def MulConstAssociative1 : Pat<
415
417
// From mul(mul(x, c1), c2).
416
418
(ONNXMulOp
417
- (ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
419
+ (ONNXMulOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
418
420
(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_)),
419
421
// To mul(x, mul(c1, c2)).
420
422
(ONNXMulOp
421
423
$x,
422
424
(ONNXMulOp $c1, $c2)),
423
- [(IsNotAConstant:$x)]>;
424
-
425
+ [(IsNotAConstant:$x), (HasOneUse:$lhs )]>;
426
+
425
427
def MulConstAssociative2 : Pat<
426
428
// From mul(mul(x, c), y).
427
429
(ONNXMulOp
428
- (ONNXMulOp $x,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
430
+ (ONNXMulOp:$lhs $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
429
431
$y),
430
432
// To mul(mul(x, y), c).
431
433
(ONNXMulOp
432
434
(ONNXMulOp $x, $y),
433
435
$c),
434
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
436
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs )]>;
435
437
436
438
def MulConstAssociative3 : Pat<
437
439
// From mul(x, mul(y, c)).
438
440
(ONNXMulOp
439
441
$x,
440
- (ONNXMulOp $y,(ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
442
+ (ONNXMulOp:$rhs $y, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_))),
441
443
// To mul(mul(x, y), c).
442
444
(ONNXMulOp
443
445
(ONNXMulOp $x, $y),
444
446
$c),
445
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
447
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs )]>;
446
448
447
449
def MulConstAssociative4 : Pat<
448
450
// From mul(mul(x, c1), mul(y, c2)).
449
451
(ONNXMulOp
450
- (ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
451
- (ONNXMulOp $y,(ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
452
+ (ONNXMulOp:$lhs $x, (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_)),
453
+ (ONNXMulOp:$rhs $y, (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_))),
452
454
// To mul(mul(x, y), c1+c2).
453
455
(ONNXMulOp
454
456
(ONNXMulOp $x, $y),
455
457
(ONNXMulOp $c1, $c2)),
456
- [(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
458
+ [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), (HasOneUse:$rhs )]>;
457
459
458
460
// Constant Propagation for Mul
459
461
def MulConstProp : Pat<
@@ -480,7 +482,7 @@ def MulOnesOnRhs : Pat<
480
482
(ValuesHaveSameType $result, $x)
481
483
]>;
482
484
483
- // Constant Propagation for Div
485
+ // Constant Propagation for Div
484
486
def DivConstProp : Pat<
485
487
// From div(c1, c2).
486
488
(ONNXDivOp:$divOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
0 commit comments