From 6a5483d0f82e08925a387dd6da1017f5d7728d94 Mon Sep 17 00:00:00 2001 From: Hengyu Meng Date: Wed, 28 Sep 2022 18:31:47 -0700 Subject: [PATCH] add support for ArgMin Signed-off-by: Hengyu Meng --- docs/SupportedONNXOps-cpu.md | 2 +- src/Conversion/ONNXToKrnl/CMakeLists.txt | 2 +- .../ONNXToKrnl/ConvertONNXToKrnl.cpp | 2 +- .../ONNXToKrnl/ONNXToKrnlCommon.hpp | 2 +- .../Tensor/{ArgMax.cpp => ArgMinMax.cpp} | 111 +++++++++++++----- test/backend/inference_backend.py | 8 +- test/mlir/onnx/onnx_shape_inference.mlir | 16 +++ 7 files changed, 106 insertions(+), 37 deletions(-) rename src/Conversion/ONNXToKrnl/Tensor/{ArgMax.cpp => ArgMinMax.cpp} (52%) diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index 02c3df33c64..2efa216733e 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -19,7 +19,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio | **Add** |14 |No support for short integers. | | | **And** |7 | | | | **ArgMax** |13 | | | -| **ArgMin** | |unsupported | | +| **ArgMin** |13 | | | | **ArrayFeatureExtractor** | |unsupported | | | **Asin** |7 | | | | **Asinh** |9 | | | diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index ddec8b7c226..a9347592f25 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -36,7 +36,7 @@ add_onnx_mlir_library(OMONNXToKrnl Sequence/SequenceInsert.cpp Sequence/SequenceLength.cpp ConvertONNXToKrnl.cpp - Tensor/ArgMax.cpp + Tensor/ArgMinMax.cpp Tensor/Compress.cpp Tensor/Concat.cpp Tensor/Constant.cpp diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 79481cef129..78346fddf5c 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -200,7 +200,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, // ObjectDetection populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx); // Tensor - populateLoweringONNXArgMaxOpPattern(patterns, typeConverter, ctx); + populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx); populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx); populateLoweringONNXReshapeOpPattern(patterns, typeConverter, ctx); populateLoweringONNXPadOpPattern(patterns, typeConverter, ctx); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 3abcc8501d6..4d79ceb97f0 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -311,7 +311,7 @@ void populateLoweringONNXSequenceLengthOpPattern( RewritePatternSet &, TypeConverter &, MLIRContext *); // `Tensor` directory methods: -void populateLoweringONNXArgMaxOpPattern( +void populateLoweringONNXArgMinMaxOpPattern( RewritePatternSet &, TypeConverter &, MLIRContext *); void populateLoweringONNXDimOpPattern( RewritePatternSet &, TypeConverter &, MLIRContext *); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp similarity index 52% rename from src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp rename to src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp index b3457cd4bf1..20c39fb0763 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp @@ -2,13 +2,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===---------------- ArgMax.cpp - Lowering ArgMax Op -------------------===// +//===------------ ArgMinMax.cpp - Lowering ArgMin/ArgMax Op ---------------===// // // Copyright 2021-2022 The IBM Research Authors. // // ============================================================================= // -// This file lowers the ONNX ArgMax Operator to Krnl dialect. +// This file lowers the ONNX ArgMin/ArgMax Operator to Krnl dialect. // //===----------------------------------------------------------------------===// @@ -19,27 +19,72 @@ using namespace mlir; namespace onnx_mlir { -struct ONNXArgMaxOpLowering : public ConversionPattern { - ONNXArgMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) - : ConversionPattern( - typeConverter, mlir::ONNXArgMaxOp::getOperationName(), 1, ctx) {} + +template +inline Value getCondition(MultiDialectBuilder create, + Value next, Value dstVal); + +template <> +inline Value getCondition( + MultiDialectBuilder create, Value next, + Value dstVal) { + return create.math.slt(next, dstVal); +} + +template <> +inline Value getCondition( + MultiDialectBuilder create, Value next, + Value dstVal) { + return create.math.sgt(next, dstVal); +} + +template +inline llvm::SmallVector getOutputDims(ArgOp *op, + typename ArgOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, + ArrayValueIndexCapture::GetDenseVal fGetDenseVal, + ArrayValueIndexCapture::LoadVal fLoadVal); + +template <> +inline llvm::SmallVector getOutputDims( + ONNXArgMinOp *op, typename ONNXArgMinOp::Adaptor operandAdaptor, + mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, + ArrayValueIndexCapture::LoadVal fLoadVal) { + ONNXArgMinOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal); + auto shapecomputed = shapeHelper.computeShape(operandAdaptor); + (void)shapecomputed; + assert(!failed(shapecomputed) && "expected to succeed"); + return shapeHelper.dimsForOutput(); +} + +template <> +inline llvm::SmallVector getOutputDims( + ONNXArgMaxOp *op, typename ONNXArgMaxOp::Adaptor operandAdaptor, + mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, + ArrayValueIndexCapture::LoadVal fLoadVal) { + ONNXArgMaxOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal); + auto shapecomputed = shapeHelper.computeShape(operandAdaptor); + (void)shapecomputed; + assert(!failed(shapecomputed) && "expected to succeed"); + return shapeHelper.dimsForOutput(); +} + +template +struct ONNXArgMinMaxOpLowering : public ConversionPattern { + ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Gather info. auto loc = op->getLoc(); - ONNXArgMaxOpAdaptor operandAdaptor(operands); - ONNXArgMaxOp argMaxOp = llvm::cast(op); + IndexExprScope scope(&rewriter, loc); + ArgOp argOp = llvm::cast(op); - // shape helper - ONNXArgMaxOpShapeHelper shapeHelper(&argMaxOp, &rewriter, + typename ArgOp::Adaptor operandAdaptor(operands); + auto OutputDims = getOutputDims(&argOp, operandAdaptor, &rewriter, krnl::getDenseElementAttributeFromKrnlValue, krnl::loadDenseElementArrayValueAtIndex); - auto shapecomputed = shapeHelper.computeShape(operandAdaptor); - (void)shapecomputed; - assert(!failed(shapecomputed) && "expected to succeed"); - // Convert the reduced output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); assert(convertedType && convertedType.isa() && @@ -49,16 +94,16 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { int64_t reducedRank = reducedMemRefType.getRank(); // data input - auto data = operandAdaptor.data(); - auto dataType = data.getType().cast(); + Value data = operandAdaptor.data(); + MemRefType dataType = data.getType().cast(); int64_t dataRank = dataType.getRank(); // axis & keepdims attribute - int64_t axis = argMaxOp.axis(); + int64_t axis = argOp.axis(); assert(axis >= -dataRank && axis <= dataRank - 1); axis = axis >= 0 ? axis : (dataRank + axis); - int64_t keepdims = argMaxOp.keepdims(); + int64_t keepdims = argOp.keepdims(); bool isKeepdims = (keepdims == 1) ? true : false; // Get type information @@ -69,7 +114,7 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { // Insert alloc and dealloc Value alloc = insertAllocAndDeallocSimple( - rewriter, op, reducedMemRefType, loc, shapeHelper.dimsForOutput()); + rewriter, op, reducedMemRefType, loc, OutputDims); // Constant Value MathBuilder createMath(rewriter, loc); @@ -81,13 +126,12 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { // 1. Krnl loops to initialize the result. ValueRange initLoopDef = createKrnl.defineLoops(reducedRank); SmallVector initLbs(reducedRank, LiteralIndexExpr(0)); - createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, - shapeHelper.dimsForOutput(0), + createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims[0], [&](KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(minusOne, alloc, loopInd); }); - // 2. Krnl loop to calculate argmax. + // 2. Krnl loop to calculate argmin/argmax. MultiDialectBuilder create(rewriter, loc); ValueRange calcLoopDef = createKrnl.defineLoops(dataRank); SmallVector lbs(dataRank, LiteralIndexExpr(0)); @@ -97,7 +141,7 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { createKrnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { // Handle the operation: - SmallVector inLoopIVs, outLoopIVs, maxLoopIVs; + SmallVector inLoopIVs, outLoopIVs, dstLoopIVs; for (int i = 0; i < dataRank; ++i) inLoopIVs.push_back(loopInd[i]); @@ -116,21 +160,21 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { Value lessThanZero = create.math.slt(idx, zero); idx = create.math.select(lessThanZero, zero, idx); - // induction variables of current max value + // induction variables of current min/max value for (int i = 0; i < dataRank; ++i) { if (i != axis) - maxLoopIVs.push_back(loopInd[i]); + dstLoopIVs.push_back(loopInd[i]); else - maxLoopIVs.push_back(rewriter.create( + dstLoopIVs.push_back(rewriter.create( loc, rewriter.getIndexType(), idx)); } - Value maxVal = createKrnl.load(data, maxLoopIVs); + Value dstVal = createKrnl.load(data, dstLoopIVs); - // if next value is larger than current max value, update index - Value greaterThanMax = create.math.sgt(next, maxVal); + // if next value is smaller/larger than current value, update index + Value newDstVal = getCondition(create, next, dstVal); Value pos = rewriter.create( loc, rewriter.getIntegerType(64), inLoopIVs[axis]); - idx = create.math.select(greaterThanMax, pos, idx); + idx = create.math.select(newDstVal, pos, idx); createKrnl.store(idx, alloc, outLoopIVs); }); @@ -139,9 +183,12 @@ struct ONNXArgMaxOpLowering : public ConversionPattern { } }; -void populateLoweringONNXArgMaxOpPattern(RewritePatternSet &patterns, +void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); + patterns.insert>( + typeConverter, ctx); + patterns.insert>( + typeConverter, ctx); } } // namespace onnx_mlir diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index c9a63969b18..402cc88ef19 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -114,7 +114,13 @@ def get_test_models(): "test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ArgMin - + "test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_no_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # ==OP== Asin "test_asin_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_asin_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 5600d4de200..11daeccf42b 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -59,6 +59,22 @@ func.func @test_default_argmax(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> { // ----- +//===----------------------------------------------------------------------===// +/// Test the default behavior of argmin when no information for the +/// permutation of the axes is provided and when a permutation is provided. +//===----------------------------------------------------------------------===// + +func.func @test_default_argmin(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> { + %0 = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<*xi64> + "func.return"(%0) : (tensor<*xi64>) -> () + + // CHECK-LABEL: test_default_argmin + // CHECK: [[RES:%.+]] = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<1x3x4xi64> + // CHECK: return [[RES]] : tensor<1x3x4xi64> +} + +// ----- + //===----------------------------------------------------------------------===// /// Test the default behavior of transpose when no information for the /// permutation of the axes is provided and when a permutation is provided.