diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
index 2e3892324e..34336d1030 100644
--- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -662,6 +662,18 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
       }
     }
 
+    //////////////////////////////////////////////////////////////////////
+    // Reduction over all dimensions to a scalar value.
+    bool fullReduction = hasNoAxes || (rawAxesIE.size() == inRank);
+    if (fullReduction && !isKeepdims && enableSIMD) {
+      Value alloc, none;
+      if (emitFullSIMDReductionFor(
+              rewriter, loc, op, input, alloc, none, enableParallel)) {
+        rewriter.replaceOp(op, alloc);
+        return success();
+      }
+    }
+
     //////////////////////////////////////////////////////////////////////
     // Characterize literal axes: make unique and within [0, inRank).
     std::vector<int64_t> uniqueLitAxes;
diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
index 9a18e44b77..82d5e441e5 100644
--- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir
@@ -40,6 +40,47 @@ func.func @test_reduce_scalar_axes(%arg0: tensor<?x64x?xf32>) -> tensor<*xf32>
 
 // -----
 
+// COM: Full reduction over all dimensions to a scalar value.
+func.func @test_reduce_all_to_scalar(%arg0: tensor<?x64x?xf32>) -> tensor<*xf32> {
+  %axes = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, none) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)>
+// CHECK-LABEL: func.func @test_reduce_all_to_scalar
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x64x?xf32>) -> memref<f32> {
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32>
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
+// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
+// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
+// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_0_]] : index
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
+// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex>
+// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref<?x64x?xf32>, memref<1xindex>) -> memref<?xf32>
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<32xf32>
+// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<f32>
+// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
+// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[VAR_1_]]){
+// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>, vector<32xf32>
+// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
+// CHECK: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
+// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
+// CHECK: }
+// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
+// CHECK: [[VAR_4_:%.+]] = vector.reduction <maxnumf>, [[LOAD_RES_1_MEM_1_]] : vector<32xf32> into f32
+// CHECK: krnl.store [[VAR_4_]], [[RES_2_]][] : memref<f32>
+// CHECK: return [[RES_2_]] : memref<f32>
+// CHECK: }
+}
+
+// -----
+
 func.func private @test_reducemax_v13(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   %0 ="onnx.ReduceMaxV13"(%arg0) {axes=[1], keepdims = 0 : si64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
   "func.return"(%0) : (tensor<*xf32>) -> ()
diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir
index 7290d34032..b2cc41276c 100644
--- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir
+++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir
@@ -2,6 +2,85 @@
 
 // -----
 
+// COM: Full reduction over all dimensions to a scalar value.
+func.func @test_reduce_all_to_scalar(%arg0: tensor<?x64x?xf32>) -> tensor<*xf32> {
+  %axes = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, none) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)>
+// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-LABEL: func.func @test_reduce_all_to_scalar
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x64x?xf32>) -> memref<f32> {
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32>
+// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32>
+// CHECK-DAG: [[CST_31_:%.+]] = arith.constant 31 : index
+// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
+// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
+// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
+// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_1_]] : index
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
+// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex>
+// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref<?x64x?xf32>, memref<1xindex>) -> memref<?xf32>
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<256xf32>
+// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<8xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.ceildivsi [[VAR_1_]], [[CST_8_]] : index
+// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop
+// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){
+// CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
+// CHECK: [[VAR_8_:%.+]] = arith.muli [[VAR_7_]], [[VAR_2_]] : index
+// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_2_]] : index
+// CHECK: [[VAR_10_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[VAR_9_]] : index
+// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[VAR_1_]], [[VAR_9_]] : index
+// CHECK-DAG: [[VAR_12_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]])
+// CHECK: vector.store [[VAR_cst_0_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
+// CHECK: [[VAR_13_:%.+]] = arith.subi [[VAR_11_]], [[CST_31_]] : index
+// CHECK: scf.for [[I_1_:%.+]] = [[VAR_8_]] to [[VAR_13_]] step [[CST_32_]] {
+// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<?xf32>, vector<32xf32>
+// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
+// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
+// CHECK: vector.store [[VAR_22_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
+// CHECK: }
+// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_11_]], [[VAR_8_]] : index
+// CHECK: [[VAR_15_:%.+]] = arith.remsi [[VAR_14_]], [[CST_32_]] : index
+// CHECK: [[VAR_16_:%.+]] = arith.subi [[VAR_14_]], [[VAR_15_]] : index
+// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_8_]], [[VAR_16_]] : index
+// CHECK: scf.for [[I_2_:%.+]] = [[VAR_17_]] to [[VAR_11_]] step [[CST_1_]] {
+// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<?xf32>
+// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = memref.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>
+// CHECK: [[VAR_22_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : f32
+// CHECK: memref.store [[VAR_22_1_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>
+// CHECK: }
+// CHECK: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
+// CHECK: [[VAR_19_:%.+]] = vector.reduction <maxnumf>, [[LOAD_RES_1_MEM_2_]] : vector<32xf32> into f32
+// CHECK: memref.store [[VAR_19_]], [[RES_2_]]{{.}}[[VAR_7_]]{{.}} : memref<8xf32>
+// CHECK: }
+// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref<f32>
+// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32>
+// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1
+// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){
+// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[VAR_8_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<8xf32>
+// CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>
+// CHECK: [[VAR_10_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[VAR_8_1_]] : f32
+// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>
+// CHECK: }
+// CHECK: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32>
+// CHECK: [[VAR_6_:%.+]] = vector.extract [[LOAD_RES_1_MEM_4_]][0] : f32 from vector<1xf32>
+// CHECK: krnl.store [[VAR_6_]], [[RES_3_]][] : memref<f32>
+// CHECK: return [[RES_3_]] : memref<f32>
+// CHECK: }
+}
+
+// -----
+
 // With enable-parallel, a krnl.parallel should be created, which takes a loop (to be parallelized)
 // as input. The krnl.parallel should be the last operator before krnl.iterate, since the lowering
 // needs to interpret krnl.block, krnl.permute, krnl.unroll first.