From a76f70c02b68c59a1a1fe764f27f73545313f792 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Mon, 18 Nov 2024 15:50:15 +0000 Subject: [PATCH] Do not fuse locations when normalizing constants for Add and Mul Signed-off-by: Rickert, Jonas --- src/Dialect/ONNX/Transforms/ConstProp.td | 8 +++--- test/mlir/onnx/onnx_constprop_locations.mlir | 30 ++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 test/mlir/onnx/onnx_constprop_locations.mlir diff --git a/src/Dialect/ONNX/Transforms/ConstProp.td b/src/Dialect/ONNX/Transforms/ConstProp.td index 1baef13dad..408d01464e 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.td +++ b/src/Dialect/ONNX/Transforms/ConstProp.td @@ -302,9 +302,9 @@ def CreateScatterNDOfConst : // Use commutativity to normalize constants in the second position of Add. def AddConstCommutative1 : NamedPat<"AddConstCommutative1", // From add(c, x). - (ONNXAddOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), + (ONNXAddOp:$addOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), // To add(x, c). - (ONNXAddOp $x, $c), + (ONNXAddOp $x, $c, (location $addOp)), // To avoid infinite loop, constrain the first arguments to be anything but a constant. [(IsNotAConstant:$x)]>; @@ -575,9 +575,9 @@ def SumConstProp : NamedPat<"SumConstProp", // Use commutativity to normalize constants in the second position of Mul. def MulConstCommutative1 : NamedPat<"MulConstCommutative1", // From mul(c, x). - (ONNXMulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), + (ONNXMulOp:$mulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), // To mul(x, c). - (ONNXMulOp $x, $c), + (ONNXMulOp $x, $c, (location $mulOp)), // To avoid infinite loop, constrain the first arguments to be anything but a constant. [(IsNotAConstant:$x)]>; diff --git a/test/mlir/onnx/onnx_constprop_locations.mlir b/test/mlir/onnx/onnx_constprop_locations.mlir new file mode 100644 index 0000000000..c4124ca182 --- /dev/null +++ b/test/mlir/onnx/onnx_constprop_locations.mlir @@ -0,0 +1,30 @@ +// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file --mlir-print-debuginfo | FileCheck %s + + +//===----------------------------------------------------------------------===// +/// Commutative tests + +// CHECK-LABEL: @test_add_constant_1_loc +func.func @test_add_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant") + %1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Add") + "onnx.Return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]]) + // CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_ADD:#.+]]) + // CHECK-DAG: [[LOC_CONST]] = loc("Constant") + // CHECK-DAG: [[LOC_ADD]] = loc("Add") +} + +// ----- + +// CHECK-LABEL: @test_mul_constant_1_loc +func.func @test_mul_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant") + %1 = "onnx.Mul"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Mul") + "onnx.Return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]]) + // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_MUL:#.+]]) + // CHECK-DAG: [[LOC_CONST]] = loc("Constant") + // CHECK-DAG: [[LOC_MUL]] = loc("Mul") +} +