Handle full reduction over all dimensions (#3022)
Signed-off-by: Tung D. Le <[email protected]>
tungld authored Dec 4, 2024
1 parent bac09ee commit 40f5017
Showing 3 changed files with 132 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -662,6 +662,18 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
}
}

//////////////////////////////////////////////////////////////////////
// Reduction over all dimensions to a scalar value.
bool fullReduction = hasNoAxes || (rawAxesIE.size() == inRank);
if (fullReduction && !isKeepdims && enableSIMD) {
Value alloc, none;
if (emitFullSIMDReductionFor<ONNXReductionOp, ONNXNoneOp>(
rewriter, loc, op, input, alloc, none, enableParallel)) {
rewriter.replaceOp(op, alloc);
return success();
}
}

//////////////////////////////////////////////////////////////////////
// Characterize literal axes: make unique and within [0, inRank).
std::vector<int64_t> uniqueLitAxes;
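The new path above fires only when the reduction covers every dimension: either no axes are given (hasNoAxes) or the literal axes list has inRank entries. With keepdims off and SIMD enabled, emitFullSIMDReductionFor emits the whole reduction at once and replaces the op; if it declines, control falls through to the generic per-axis lowering that follows. Semantically, the emitted code collapses the flattened input to a single scalar; a minimal C++ sketch of that runtime behavior for ReduceMax on float data (the flat-buffer view and the function name are illustrative, not onnx-mlir APIs):

#include <cmath>
#include <cstddef>
#include <limits>

// Full reduction: every dimension is reduced, so an N-D tensor viewed as a
// flat buffer of numElems floats collapses to a single scalar.
float reduceMaxAll(const float *data, size_t numElems) {
  // -inf start value, the 0xFF800000 splat seen in the generated code.
  float acc = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < numElems; ++i)
    acc = std::fmax(acc, data[i]); // fmax ~ maxnumf: prefers the non-NaN operand
  return acc;
}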
@@ -40,6 +40,47 @@ func.func @test_reduce_scalar_axes(%arg0: tensor<?x64x?xf32>) -> tensor<?x?xf32>

// -----

// COM: Full reduction over all dimensions to a scalar value.
func.func @test_reduce_all_to_scalar(%arg0: tensor<?x64x?xf32>) -> tensor<*xf32> {
%axes = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, none) -> tensor<*xf32>
return %0: tensor<*xf32>

// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: func.func @test_reduce_all_to_scalar
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x64x?xf32>) -> memref<f32> {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32>
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_0_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref<?x64x?xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<32xf32>
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<f32>
// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[VAR_1_]]){
// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>, vector<32xf32>
// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
// CHECK: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
// CHECK: }
// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32>
// CHECK: [[VAR_4_:%.+]] = vector.reduction <maxnumf>, [[LOAD_RES_1_MEM_1_]] : vector<32xf32> into f32
// CHECK: krnl.store [[VAR_4_]], [[RES_2_]][] : memref<f32>
// CHECK: return [[RES_2_]] : memref<f32>
// CHECK: }
}
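
The checks above pin down the SIMD pattern: the input is reshaped to a 1-D memref of length dim0 * 64 * dim2, a 32-lane accumulator is splat-initialized to -inf (0xFF800000), the blocked loop folds 32 elements per iteration with arith.maxnumf, and a final vector.reduction <maxnumf> collapses the 32 lanes into the scalar result. A hedged C++ sketch of the same blocked-accumulator scheme (names are illustrative; it assumes numElems is a multiple of the 32-element block, as in the main loop here):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>

// Blocked max-reduction: update a 32-lane accumulator lane-wise over the
// flattened input, then reduce it horizontally, as vector.reduction does.
float reduceMaxBlocked32(const float *data, size_t numElems) {
  float lanes[32];
  std::fill(lanes, lanes + 32, -std::numeric_limits<float>::infinity());
  for (size_t i = 0; i < numElems; i += 32) // one vector.load + maxnumf per block
    for (size_t l = 0; l < 32; ++l)
      lanes[l] = std::fmax(lanes[l], data[i + l]);
  float acc = lanes[0]; // horizontal reduction across the lanes
  for (size_t l = 1; l < 32; ++l)
    acc = std::fmax(acc, lanes[l]);
  return acc;
}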

// -----

func.func private @test_reducemax_v13(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceMaxV13"(%arg0) {axes=[1], keepdims = 0 : si64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
@@ -2,6 +2,85 @@

// -----

// COM: Full reduction over all dimensions to a scalar value.
func.func @test_reduce_all_to_scalar(%arg0: tensor<?x64x?xf32>) -> tensor<*xf32> {
%axes = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, none) -> tensor<*xf32>
return %0: tensor<*xf32>

// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-LABEL: func.func @test_reduce_all_to_scalar
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x64x?xf32>) -> memref<f32> {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32>
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32>
// CHECK-DAG: [[CST_31_:%.+]] = arith.constant 31 : index
// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref<?x64x?xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<256xf32>
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<8xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = arith.ceildivsi [[VAR_1_]], [[CST_8_]] : index
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){
// CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[VAR_8_:%.+]] = arith.muli [[VAR_7_]], [[VAR_2_]] : index
// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_2_]] : index
// CHECK: [[VAR_10_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[VAR_9_]] : index
// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[VAR_1_]], [[VAR_9_]] : index
// CHECK-DAG: [[VAR_12_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]])
// CHECK: vector.store [[VAR_cst_0_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
// CHECK: [[VAR_13_:%.+]] = arith.subi [[VAR_11_]], [[CST_31_]] : index
// CHECK: scf.for [[I_1_:%.+]] = [[VAR_8_]] to [[VAR_13_]] step [[CST_32_]] {
// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<?xf32>, vector<32xf32>
// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32>
// CHECK: vector.store [[VAR_22_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
// CHECK: }
// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_11_]], [[VAR_8_]] : index
// CHECK: [[VAR_15_:%.+]] = arith.remsi [[VAR_14_]], [[CST_32_]] : index
// CHECK: [[VAR_16_:%.+]] = arith.subi [[VAR_14_]], [[VAR_15_]] : index
// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_8_]], [[VAR_16_]] : index
// CHECK: scf.for [[I_2_:%.+]] = [[VAR_17_]] to [[VAR_11_]] step [[CST_1_]] {
// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<?xf32>
// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = memref.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>
// CHECK: [[VAR_22_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : f32
// CHECK: memref.store [[VAR_22_1_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>
// CHECK: }
// CHECK: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32>
// CHECK: [[VAR_19_:%.+]] = vector.reduction <maxnumf>, [[LOAD_RES_1_MEM_2_]] : vector<32xf32> into f32
// CHECK: memref.store [[VAR_19_]], [[RES_2_]]{{.}}[[VAR_7_]]{{.}} : memref<8xf32>
// CHECK: }
// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref<f32>
// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32>
// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){
// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[VAR_8_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<8xf32>
// CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>
// CHECK: [[VAR_10_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[VAR_8_1_]] : f32
// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>
// CHECK: }
// CHECK: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32>
// CHECK: [[VAR_6_:%.+]] = vector.extract [[LOAD_RES_1_MEM_4_]][0] : f32 from vector<1xf32>
// CHECK: krnl.store [[VAR_6_]], [[RES_3_]][] : memref<f32>
// CHECK: return [[RES_3_]] : memref<f32>
// CHECK: }
}
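
The parallel variant splits the flat iteration space into 8 chunks (arith.ceildivsi by 8). Each parallel iteration owns a 32-float slice of the 256-float scratch buffer: it runs the 32-wide SIMD loop, then a scalar remainder loop for the trailing elements, reduces its slice horizontally, and stores one partial into memref<8xf32>; a final sequential loop folds the 8 partials into the scalar result. The chunk-bounds arithmetic in the checks corresponds to this sketch (a minimal C++ rendering; the names are illustrative):

#include <algorithm>
#include <cstddef>

// Bounds of one parallel chunk: total elements split into numChunks pieces
// of ceil(total / numChunks); the end is clamped to total, matching the
// arith.ceildivsi / arith.cmpi / arith.select sequence in the checks.
struct Chunk { size_t begin, end; };

Chunk chunkBounds(size_t total, size_t numChunks, size_t tid) {
  size_t size = (total + numChunks - 1) / numChunks; // ceildiv
  size_t begin = tid * size;
  size_t end = std::min(total, begin + size); // clamp the last chunk
  return {begin, end};
}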

// -----

// With enable-parallel, a krnl.parallel should be created, taking as input the loop
// to be parallelized. The krnl.parallel should be the last operator before krnl.iterate,
// since the lowering needs to interpret krnl.block, krnl.permute, and krnl.unroll first.