diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md
index 02c3df33c6..2efa216733 100644
--- a/docs/SupportedONNXOps-cpu.md
+++ b/docs/SupportedONNXOps-cpu.md
@@ -19,7 +19,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
 | **Add** |14 |No support for short integers. | |
 | **And** |7 | | |
 | **ArgMax** |13 | | |
-| **ArgMin** | |unsupported | |
+| **ArgMin** |13 | | |
 | **ArrayFeatureExtractor** | |unsupported | |
 | **Asin** |7 | | |
 | **Asinh** |9 | | |
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index ddec8b7c22..a9347592f2 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -36,7 +36,7 @@ add_onnx_mlir_library(OMONNXToKrnl
   Sequence/SequenceInsert.cpp
   Sequence/SequenceLength.cpp
   ConvertONNXToKrnl.cpp
-  Tensor/ArgMax.cpp
+  Tensor/ArgMinMax.cpp
   Tensor/Compress.cpp
   Tensor/Concat.cpp
   Tensor/Constant.cpp
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 79481cef12..78346fddf5 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -200,7 +200,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
   // ObjectDetection
   populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx);
   // Tensor
-  populateLoweringONNXArgMaxOpPattern(patterns, typeConverter, ctx);
+  populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXReshapeOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXPadOpPattern(patterns, typeConverter, ctx);
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 3abcc8501d..4d79ceb97f 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -311,7 +311,7 @@ void populateLoweringONNXSequenceLengthOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 
 // `Tensor` directory methods:
-void populateLoweringONNXArgMaxOpPattern(
+void populateLoweringONNXArgMinMaxOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXDimOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp
similarity index 61%
rename from src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp
rename to src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp
index b3457cd4bf..79e9300bf4 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/ArgMax.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp
@@ -2,13 +2,13 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-//===---------------- ArgMax.cpp - Lowering ArgMax Op -------------------===//
+//===------------ ArgMinMax.cpp - Lowering ArgMin/ArgMax Op ---------------===//
 //
 // Copyright 2021-2022 The IBM Research Authors.
 //
 // =============================================================================
 //
-// This file lowers the ONNX ArgMax Operator to Krnl dialect.
+// This file lowers the ONNX ArgMin/ArgMax Operators to Krnl dialect.
 //
 //===----------------------------------------------------------------------===//
 
@@ -19,26 +19,44 @@
 using namespace mlir;
 
 namespace onnx_mlir {
 
-struct ONNXArgMaxOpLowering : public ConversionPattern {
-  ONNXArgMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
-      : ConversionPattern(
-            typeConverter, mlir::ONNXArgMaxOp::getOperationName(), 1, ctx) {}
+template <typename ArgOp>
+inline Value getCondition(MultiDialectBuilder<KrnlBuilder, MathBuilder> create,
+    Value next, Value dstVal);
+
+template <>
+inline Value getCondition<ONNXArgMinOp>(
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
+    Value dstVal) {
+  return create.math.slt(next, dstVal);
+}
+
+template <>
+inline Value getCondition<ONNXArgMaxOp>(
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
+    Value dstVal) {
+  return create.math.sgt(next, dstVal);
+}
+
+template <typename ArgOp, typename OpShapeHelper>
+struct ONNXArgMinMaxOpLowering : public ConversionPattern {
+  ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {}
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     // Gather info.
     auto loc = op->getLoc();
-    ONNXArgMaxOpAdaptor operandAdaptor(operands);
-    ONNXArgMaxOp argMaxOp = llvm::cast<ONNXArgMaxOp>(op);
+    IndexExprScope scope(&rewriter, loc);
+    ArgOp argOp = llvm::cast<ArgOp>(op);
 
-    // shape helper
-    ONNXArgMaxOpShapeHelper shapeHelper(&argMaxOp, &rewriter,
+    typename ArgOp::Adaptor operandAdaptor(operands);
+    OpShapeHelper shapeHelper(&argOp, &rewriter,
         krnl::getDenseElementAttributeFromKrnlValue,
         krnl::loadDenseElementArrayValueAtIndex);
-    auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
-    (void)shapecomputed;
-    assert(!failed(shapecomputed) && "expected to succeed");
+    auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
+    assert(succeeded(shapecomputed) && "Could not compute output shape");
+    DimsExpr outputDims = shapeHelper.dimsForOutput();
 
     // Convert the reduced output type to MemRefType.
     Type convertedType = typeConverter->convertType(*op->result_type_begin());
@@ -49,16 +67,16 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
     int64_t reducedRank = reducedMemRefType.getRank();
 
     // data input
-    auto data = operandAdaptor.data();
-    auto dataType = data.getType().cast<MemRefType>();
+    Value data = operandAdaptor.data();
+    MemRefType dataType = data.getType().cast<MemRefType>();
     int64_t dataRank = dataType.getRank();
 
     // axis & keepdims attribute
-    int64_t axis = argMaxOp.axis();
+    int64_t axis = argOp.axis();
     assert(axis >= -dataRank && axis <= dataRank - 1);
     axis = axis >= 0 ? axis : (dataRank + axis);
 
-    int64_t keepdims = argMaxOp.keepdims();
+    int64_t keepdims = argOp.keepdims();
     bool isKeepdims = (keepdims == 1) ? true : false;
 
     // Get type information
@@ -69,7 +87,7 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
 
     // Insert alloc and dealloc
     Value alloc = insertAllocAndDeallocSimple(
-        rewriter, op, reducedMemRefType, loc, shapeHelper.dimsForOutput());
+        rewriter, op, reducedMemRefType, loc, outputDims);
 
     // Constant Value
     MathBuilder createMath(rewriter, loc);
@@ -81,14 +99,13 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
     // 1. Krnl loops to initialize the result.
     ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
     SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
-    createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs,
-        shapeHelper.dimsForOutput(0),
+    createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims,
         [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
           createKrnl.store(minusOne, alloc, loopInd);
         });
 
-    // 2. Krnl loop to calculate argmax.
-    MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
+    // 2. Krnl loop to calculate argmin/argmax.
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
     ValueRange calcLoopDef = createKrnl.defineLoops(dataRank);
     SmallVector<IndexExpr, 4> lbs(dataRank, LiteralIndexExpr(0));
     MemRefBoundsIndexCapture dataBounds(data);
@@ -97,7 +114,7 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
     createKrnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs,
         [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
           // Handle the operation:
-          SmallVector<Value, 4> inLoopIVs, outLoopIVs, maxLoopIVs;
+          SmallVector<Value, 4> inLoopIVs, outLoopIVs, dstLoopIVs;
 
           for (int i = 0; i < dataRank; ++i)
             inLoopIVs.push_back(loopInd[i]);
@@ -116,21 +133,21 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
           Value lessThanZero = create.math.slt(idx, zero);
           idx = create.math.select(lessThanZero, zero, idx);
 
-          // induction variables of current max value
+          // induction variables of current min/max value
           for (int i = 0; i < dataRank; ++i) {
             if (i != axis)
-              maxLoopIVs.push_back(loopInd[i]);
+              dstLoopIVs.push_back(loopInd[i]);
             else
-              maxLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
+              dstLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
                   loc, rewriter.getIndexType(), idx));
           }
-          Value maxVal = createKrnl.load(data, maxLoopIVs);
+          Value dstVal = createKrnl.load(data, dstLoopIVs);
 
-          // if next value is larger than current max value, update index
-          Value greaterThanMax = create.math.sgt(next, maxVal);
+          // if next value is smaller/larger than current value, update index
+          Value newDstVal = getCondition<ArgOp>(create, next, dstVal);
           Value pos = rewriter.create<arith::IndexCastOp>(
               loc, rewriter.getIntegerType(64), inLoopIVs[axis]);
-          idx = create.math.select(greaterThanMax, pos, idx);
+          idx = create.math.select(newDstVal, pos, idx);
 
           createKrnl.store(idx, alloc, outLoopIVs);
         });
@@ -139,9 +156,14 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
   }
 };
 
-void populateLoweringONNXArgMaxOpPattern(RewritePatternSet &patterns,
+void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns,
     TypeConverter &typeConverter, MLIRContext *ctx) {
-  patterns.insert<ONNXArgMaxOpLowering>(typeConverter, ctx);
+  patterns.insert<
+      ONNXArgMinMaxOpLowering<mlir::ONNXArgMinOp, ONNXArgMinOpShapeHelper>>(
+      typeConverter, ctx);
+  patterns.insert<
+      ONNXArgMinMaxOpLowering<mlir::ONNXArgMaxOp, ONNXArgMaxOpShapeHelper>>(
+      typeConverter, ctx);
 }
 
 } // namespace onnx_mlir
diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py
index c9a63969b1..099bd488a1 100644
--- a/test/backend/inference_backend.py
+++ b/test/backend/inference_backend.py
@@ -113,8 +113,14 @@ def get_test_models():
         "test_argmax_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
         "test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
 
-        # ArgMin
-
+        # ==OP== ArgMin
+        "test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_no_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + "test_argmin_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # ==OP== Asin "test_asin_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_asin_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 5600d4de20..11daeccf42 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -59,6 +59,22 @@ func.func @test_default_argmax(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> { // ----- +//===----------------------------------------------------------------------===// +/// Test the default behavior of argmin when no information for the +/// permutation of the axes is provided and when a permutation is provided. +//===----------------------------------------------------------------------===// + +func.func @test_default_argmin(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> { + %0 = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<*xi64> + "func.return"(%0) : (tensor<*xi64>) -> () + + // CHECK-LABEL: test_default_argmin + // CHECK: [[RES:%.+]] = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<1x3x4xi64> + // CHECK: return [[RES]] : tensor<1x3x4xi64> +} + +// ----- + //===----------------------------------------------------------------------===// /// Test the default behavior of transpose when no information for the /// permutation of the axes is provided and when a permutation is provided.