Skip to content

Commit

Permalink
Do not fuse locations when normalizing constants for Add and Mul (#3016)
Browse files Browse the repository at this point in the history
Signed-off-by: Rickert, Jonas <[email protected]>
  • Loading branch information
jorickert authored Dec 4, 2024
1 parent fb09556 commit bac09ee
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/Dialect/ONNX/Transforms/ConstProp.td
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ def CreateScatterNDOfConst :
// Use commutativity to normalize constants in the second position of Add.
def AddConstCommutative1 : NamedPat<"AddConstCommutative1",
// From add(c, x).
(ONNXAddOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXAddOp:$addOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
// To add(x, c).
(ONNXAddOp $x, $c),
(ONNXAddOp $x, $c, (location $addOp)),
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
[(IsNotAConstant:$x)]>;

Expand Down Expand Up @@ -575,9 +575,9 @@ def SumConstProp : NamedPat<"SumConstProp",
// Use commutativity to normalize constants in the second position of Mul.
def MulConstCommutative1 : NamedPat<"MulConstCommutative1",
// From mul(c, x).
(ONNXMulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXMulOp:$mulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
// To mul(x, c).
(ONNXMulOp $x, $c),
(ONNXMulOp $x, $c, (location $mulOp)),
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
[(IsNotAConstant:$x)]>;

Expand Down
30 changes: 30 additions & 0 deletions test/mlir/onnx/onnx_constprop_locations.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file --mlir-print-debuginfo | FileCheck %s


//===----------------------------------------------------------------------===//
/// Commutative tests

// CHECK-LABEL: @test_add_constant_1_loc
func.func @test_add_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant")
%1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Add")
"onnx.Return"(%1) : (tensor<3xf32>) -> ()
// CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]])
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_ADD:#.+]])
// CHECK-DAG: [[LOC_CONST]] = loc("Constant")
// CHECK-DAG: [[LOC_ADD]] = loc("Add")
}

// -----

// CHECK-LABEL: @test_mul_constant_1_loc
func.func @test_mul_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant")
%1 = "onnx.Mul"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Mul")
"onnx.Return"(%1) : (tensor<3xf32>) -> ()
// CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]])
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_MUL:#.+]])
// CHECK-DAG: [[LOC_CONST]] = loc("Constant")
// CHECK-DAG: [[LOC_MUL]] = loc("Mul")
}

0 comments on commit bac09ee

Please sign in to comment.