Add support for ArgMin
airMeng committed Sep 26, 2022
1 parent 20667cd commit b6b7b89
Showing 7 changed files with 175 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
@@ -19,7 +19,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
| **Add** |14 |No support for short integers. | |
| **And** |7 | | |
| **ArgMax** |13 | | |
| **ArgMin** | |unsupported | |
| **ArgMin** |13 | | |
| **ArrayFeatureExtractor** | |unsupported | |
| **Asin** |7 | | |
| **Asinh** |9 | | |
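For reference, ONNX ArgMin (opset 13) returns the index of the minimum value along a single axis as i64, with keepdims controlling whether the reduced axis is retained as size 1. A minimal numpy sketch of those semantics, added here for exposition only (it ignores the select_last_index attribute, which is also part of opset 13):

import numpy as np

def onnx_argmin(data, axis=0, keepdims=1):
    # Indices of the minimum along `axis`, as int64 (ONNX returns i64).
    result = np.argmin(data, axis=axis).astype(np.int64)
    if keepdims:
        # Keep the reduced axis as a size-1 dimension.
        result = np.expand_dims(result, axis=axis)
    return result

data = np.array([[2.0, 1.0], [3.0, 10.0]])
print(onnx_argmin(data, axis=1, keepdims=0))  # [1 0]
print(onnx_argmin(data, axis=1, keepdims=1))  # [[1] [0]]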
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -37,6 +37,7 @@ add_onnx_mlir_library(OMONNXToKrnl
Sequence/SequenceLength.cpp
ConvertONNXToKrnl.cpp
Tensor/ArgMax.cpp
Tensor/ArgMin.cpp
Tensor/Compress.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -201,6 +201,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx);
// Tensor
populateLoweringONNXArgMaxOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXArgMinOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXReshapeOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXPadOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXUnsqueezeOpPattern(patterns, typeConverter, ctx);
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -313,6 +313,8 @@ void populateLoweringONNXSequenceLengthOpPattern(
// `Tensor` directory methods:
void populateLoweringONNXArgMaxOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
void populateLoweringONNXArgMinOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
void populateLoweringONNXUnsqueezeOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
void populateLoweringONNXUnsqueezeV11OpPattern(
147 changes: 147 additions & 0 deletions src/Conversion/ONNXToKrnl/Tensor/ArgMin.cpp
@@ -0,0 +1,147 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ArgMin.cpp - Lowering ArgMin Op -------------------===//
//
// Copyright 2021-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX ArgMin Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"

using namespace mlir;

namespace onnx_mlir {
struct ONNXArgMinOpLowering : public ConversionPattern {
ONNXArgMinOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(
typeConverter, mlir::ONNXArgMinOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
auto loc = op->getLoc();
ONNXArgMinOpAdaptor operandAdaptor(operands);
ONNXArgMinOp argMinOp = llvm::cast<ONNXArgMinOp>(op);

// shape helper
ONNXArgMinOpShapeHelper shapeHelper(&argMinOp, &rewriter,
krnl::getDenseElementAttributeFromKrnlValue,
krnl::loadDenseElementArrayValueAtIndex);

auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
(void)shapecomputed;
assert(!failed(shapecomputed) && "expected to succeed");

// Convert the reduced output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
assert(convertedType && convertedType.isa<MemRefType>() &&
"Failed to convert type to MemRefType");
MemRefType reducedMemRefType = convertedType.cast<MemRefType>();
Type reducedElementType = reducedMemRefType.getElementType();
int64_t reducedRank = reducedMemRefType.getRank();

// data input
auto data = operandAdaptor.data();
auto dataType = data.getType().cast<MemRefType>();
int64_t dataRank = dataType.getRank();

// axis & keepdims attribute
int64_t axis = argMinOp.axis();
assert(axis >= -dataRank && axis <= dataRank - 1);
axis = axis >= 0 ? axis : (dataRank + axis);

int64_t keepdims = argMinOp.keepdims();
bool isKeepdims = (keepdims == 1);

// Get type information
llvm::SmallVector<int64_t, 1> axes;
axes.push_back(axis);
std::map<int64_t, int64_t> outInDimMap =
getReductionMapping(dataType, llvm::makeArrayRef(axes), isKeepdims);
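// Note (added for exposition; an assumption based on how the map is used
// below): outInDimMap maps each output dim to the input dim it indexes.
// E.g. for a 2x3x4 input with axis = 1: with keepdims, output dims 0 and 2
// map to input dims 0 and 2, and the reduced dim 1 is absent from the map
// (it is indexed with zeroIndex below); without keepdims, output dims
// 0 and 1 map to input dims 0 and 2.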

// Insert alloc and dealloc
Value alloc = insertAllocAndDeallocSimple(
rewriter, op, reducedMemRefType, loc, shapeHelper.dimsForOutput());

// Constant values; -1 marks a result entry with no index recorded yet.
MathBuilder createMath(rewriter, loc);
Value minusOne = createMath.constant(reducedElementType, -1);
Value zero = createMath.constant(reducedElementType, 0);
auto zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
KrnlBuilder createKrnl(rewriter, loc);

// 1. Krnl loops to initialize the result.
ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs,
shapeHelper.dimsForOutput(0),
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
createKrnl.store(minusOne, alloc, loopInd);
});

// 2. Krnl loop to calculate argmin.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
ValueRange calcLoopDef = createKrnl.defineLoops(dataRank);
SmallVector<IndexExpr, 4> lbs(dataRank, LiteralIndexExpr(0));
MemRefBoundsIndexCapture dataBounds(data);
SmallVector<IndexExpr, 4> ubs;
dataBounds.getDimList(ubs);
createKrnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Handle the operation:
SmallVector<Value, 4> inLoopIVs, outLoopIVs, minLoopIVs;

for (int i = 0; i < dataRank; ++i)
inLoopIVs.push_back(loopInd[i]);

for (int i = 0; i < reducedRank; ++i) {
if (outInDimMap.find(i) != outInDimMap.end())
outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
else
outLoopIVs.push_back(zeroIndex);
}

Value next = createKrnl.load(data, inLoopIVs);
Value idx = createKrnl.load(alloc, outLoopIVs);

// if the stored index is still the -1 sentinel, use 0 as initial position
Value lessThanZero = create.math.slt(idx, zero);
idx = create.math.select(lessThanZero, zero, idx);

// induction variables of current min value
for (int i = 0; i < dataRank; ++i) {
if (i != axis)
minLoopIVs.push_back(loopInd[i]);
else
minLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), idx));
}
Value minVal = createKrnl.load(data, minLoopIVs);

// if next value is smaller than current min value, update index
Value lessThanMin = create.math.slt(next, minVal);
Value pos = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(64), inLoopIVs[axis]);
idx = create.math.select(lessThanMin, pos, idx);
createKrnl.store(idx, alloc, outLoopIVs);
});

rewriter.replaceOp(op, alloc);
return success();
}
};

void populateLoweringONNXArgMinOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
patterns.insert<ONNXArgMinOpLowering>(typeConverter, ctx);
}

} // namespace onnx_mlir
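The lowering above implements argmin with two Krnl loop nests: the first initializes every output entry to -1 (a sentinel meaning "no index recorded yet"), and the second scans the input, replacing the stored index whenever a strictly smaller value is found. A minimal Python sketch of that scan for a 1-D input (an illustration of the logic, not the generated Krnl code):

def argmin_scan(data):
    # First loop nest in the lowering: initialize the result to -1.
    idx = -1
    # Second loop nest: scan and update.
    for i in range(len(data)):
        cur = 0 if idx < 0 else idx  # treat the -1 sentinel as position 0
        if data[i] < data[cur]:      # strict <, so ties keep the first index
            idx = i
        else:
            idx = cur
    return idx

assert argmin_scan([3.0, 1.0, 2.0]) == 1
assert argmin_scan([5.0]) == 0

Because the comparison is strict, ties resolve to the first occurrence, which matches ONNX's default select_last_index = 0.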
8 changes: 7 additions & 1 deletion test/backend/inference_backend.py
@@ -114,7 +114,13 @@ def get_test_models():
"test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ArgMin

"test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_no_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== Asin
"test_asin_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_asin_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
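The six cases enabled above mirror the upstream ONNX backend tests for ArgMin across the keepdims/no_keepdims, explicit-axis/default-axis, and example/random variants, using the same {STATIC_SHAPE, DYNAMIC_SHAPE, CONSTANT_INPUT} harness configuration as the ArgMax entries above. As a hedged illustration of what the default-axis cases exercise (the actual test tensors live in the onnx package, not in this diff):

import numpy as np

# With no axis attribute, ONNX ArgMin reduces axis 0, and keepdims
# defaults to 1, so the reduced axis stays as size 1.
data = np.array([[2.0, 1.0], [3.0, 10.0]])
out = np.expand_dims(np.argmin(data, axis=0), axis=0)
assert out.tolist() == [[0, 0]] and out.shape == (1, 2)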
16 changes: 16 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -59,6 +59,22 @@ func.func @test_default_argmax(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {

// -----

//===----------------------------------------------------------------------===//
/// Test the default behavior of argmin when no axis or keepdims
/// attributes are provided (axis = 0, keepdims = 1).
//===----------------------------------------------------------------------===//

func.func @test_default_argmin(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {
%0 = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<*xi64>
"func.return"(%0) : (tensor<*xi64>) -> ()

// CHECK-LABEL: test_default_argmin
// CHECK: [[RES:%.+]] = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<1x3x4xi64>
// CHECK: return [[RES]] : tensor<1x3x4xi64>
}
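The inferred 1x3x4xi64 type follows from the attribute defaults (axis = 0, keepdims = 1): the reduced axis collapses to size 1 and the element type becomes i64. A quick numpy cross-check of the expected shape (illustrative only):

import numpy as np

x = np.zeros((2, 3, 4), dtype=np.float32)
out = np.expand_dims(np.argmin(x, axis=0), axis=0).astype(np.int64)
assert out.shape == (1, 3, 4)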

// -----

//===----------------------------------------------------------------------===//
/// Test the default behavior of transpose when no information for the
/// permutation of the axes is provided and when a permutation is provided.
