Add support for ArgMin
Signed-off-by: Hengyu Meng <[email protected]>
airMeng committed Sep 27, 2022
1 parent 6d02941 commit 8dc256e
Showing 7 changed files with 121 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
@@ -19,7 +19,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
| **Add** |14 |No support for short integers. | |
| **And** |7 | | |
| **ArgMax** |13 | | |
-| **ArgMin** | |unsupported | |
+| **ArgMin** |13 | | |
| **ArrayFeatureExtractor** | |unsupported | |
| **Asin** |7 | | |
| **Asinh** |9 | | |
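An editor's aside, not part of the commit: at opset 13, ArgMin returns int64 indices of the minimum values along `axis` (default 0), keeps the reduced dimension when `keepdims` is 1 (the default), and resolves ties to the first occurrence because `select_last_index` defaults to 0. A minimal C++ reference sketch of those defaults on one axis of a row-major matrix follows; the function name and the rank-2 restriction are illustrative assumptions, not code from this repository.

```cpp
#include <cstdint>
#include <vector>

// Reference argmin over axis 1 of a row-major [rows x cols] matrix,
// keepdims ignored for simplicity. Strict '<' keeps the first index on
// ties, matching ONNX's default select_last_index = 0.
std::vector<int64_t> argMinAxis1(
    const std::vector<float> &data, int64_t rows, int64_t cols) {
  std::vector<int64_t> result(rows, 0);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 1; c < cols; ++c)
      if (data[r * cols + c] < data[r * cols + result[r]])
        result[r] = c;
  return result;
}

// For example: argMinAxis1({2, 1, 3, 10}, 2, 2) yields {1, 0}.
```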
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -36,7 +36,7 @@ add_onnx_mlir_library(OMONNXToKrnl
Sequence/SequenceInsert.cpp
Sequence/SequenceLength.cpp
ConvertONNXToKrnl.cpp
-  Tensor/ArgMax.cpp
+  Tensor/ArgMinMax.cpp
Tensor/Compress.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -200,7 +200,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
// ObjectDetection
populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx);
// Tensor
-  populateLoweringONNXArgMaxOpPattern(patterns, typeConverter, ctx);
+  populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXReshapeOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXPadOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXUnsqueezeOpPattern(patterns, typeConverter, ctx);
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -311,7 +311,7 @@ void populateLoweringONNXSequenceLengthOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);

// `Tensor` directory methods:
-void populateLoweringONNXArgMaxOpPattern(
+void populateLoweringONNXArgMinMaxOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
void populateLoweringONNXUnsqueezeOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
125 changes: 94 additions & 31 deletions src/Conversion/ONNXToKrnl/Tensor/{ArgMax.cpp → ArgMinMax.cpp}
@@ -2,13 +2,13 @@
* SPDX-License-Identifier: Apache-2.0
*/

-//===---------------- ArgMax.cpp - Lowering ArgMax Op -------------------===//
+//===------------- ArgMinMax.cpp - Lowering ArgMin/ArgMax Ops -----------===//
//
// Copyright 2021-2022 The IBM Research Authors.
//
// =============================================================================
//
-// This file lowers the ONNX ArgMax Operator to Krnl dialect.
+// This file lowers the ONNX ArgMin and ArgMax operators to the Krnl dialect.
//
//===----------------------------------------------------------------------===//

@@ -19,27 +19,88 @@
using namespace mlir;

namespace onnx_mlir {
-struct ONNXArgMaxOpLowering : public ConversionPattern {
-  ONNXArgMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
-      : ConversionPattern(
-            typeConverter, mlir::ONNXArgMaxOp::getOperationName(), 1, ctx) {}
+
+template <typename ArgOp>
+inline Value isNewVal(MultiDialectBuilder<KrnlBuilder, MathBuilder> create,
+    Value next, Value dstVal);
+
+template <>
+inline Value isNewVal<ONNXArgMinOp>(
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
+    Value dstVal) {
+  return create.math.slt(next, dstVal);
+}
+
+template <>
+inline Value isNewVal<ONNXArgMaxOp>(
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
+    Value dstVal) {
+  return create.math.sgt(next, dstVal);
+}
+
+template <typename ArgOp>
+inline MemRefType initInput(typename ArgOp::Adaptor operandAdaptor);
+
+template <>
+inline MemRefType initInput<ONNXArgMinOp>(
+    typename ONNXArgMinOp::Adaptor operandAdaptor) {
+  Value data = operandAdaptor.data();
+  return data.getType().cast<MemRefType>();
+}
+
+template <>
+inline MemRefType initInput<ONNXArgMaxOp>(
+    typename ONNXArgMaxOp::Adaptor operandAdaptor) {
+  Value data = operandAdaptor.data();
+  return data.getType().cast<MemRefType>();
+}
+
+template <typename ArgOp>
+inline llvm::SmallVector<IndexExpr, 4> getOutputDims(ArgOp *op,
+    typename ArgOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter,
+    ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
+    ArrayValueIndexCapture::LoadVal fLoadVal);
+
+template <>
+inline llvm::SmallVector<IndexExpr, 4> getOutputDims<ONNXArgMinOp>(
+    ONNXArgMinOp *op, typename ONNXArgMinOp::Adaptor operandAdaptor,
+    mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
+    ArrayValueIndexCapture::LoadVal fLoadVal) {
+  ONNXArgMinOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal);
+  auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
+  (void)shapecomputed;
+  assert(!failed(shapecomputed) && "expected to succeed");
+  return shapeHelper.dimsForOutput();
+}
+
+template <>
+inline llvm::SmallVector<IndexExpr, 4> getOutputDims<ONNXArgMaxOp>(
+    ONNXArgMaxOp *op, typename ONNXArgMaxOp::Adaptor operandAdaptor,
+    mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
+    ArrayValueIndexCapture::LoadVal fLoadVal) {
+  ONNXArgMaxOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal);
+  auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
+  (void)shapecomputed;
+  assert(!failed(shapecomputed) && "expected to succeed");
+  return shapeHelper.dimsForOutput();
+}
+
+template <typename ArgOp>
+struct ONNXArgMinMaxOpLowering : public ConversionPattern {
+  ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
auto loc = op->getLoc();
-    ONNXArgMaxOpAdaptor operandAdaptor(operands);
-    ONNXArgMaxOp argMaxOp = llvm::cast<ONNXArgMaxOp>(op);
+    ArgOp argOp = llvm::cast<ArgOp>(op);

-    // shape helper
-    ONNXArgMaxOpShapeHelper shapeHelper(&argMaxOp, &rewriter,
+    typename ArgOp::Adaptor operandAdaptor(operands);
+    auto outputDims = getOutputDims<ArgOp>(&argOp, operandAdaptor, &rewriter,
krnl::getDenseElementAttributeFromKrnlValue,
krnl::loadDenseElementArrayValueAtIndex);

-    auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
-    (void)shapecomputed;
-    assert(!failed(shapecomputed) && "expected to succeed");

// Convert the reduced output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
assert(convertedType && convertedType.isa<MemRefType>() &&
@@ -49,16 +110,16 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
int64_t reducedRank = reducedMemRefType.getRank();

// data input
+    auto dataType = initInput<ArgOp>(operandAdaptor);
auto data = operandAdaptor.data();
-    auto dataType = data.getType().cast<MemRefType>();
int64_t dataRank = dataType.getRank();

// axis & keepdims attribute
-    int64_t axis = argMaxOp.axis();
+    int64_t axis = argOp.axis();
assert(axis >= -dataRank && axis <= dataRank - 1);
axis = axis >= 0 ? axis : (dataRank + axis);

-    int64_t keepdims = argMaxOp.keepdims();
+    int64_t keepdims = argOp.keepdims();
bool isKeepdims = (keepdims == 1) ? true : false;

// Get type information
@@ -69,7 +130,7 @@

// Insert alloc and dealloc
Value alloc = insertAllocAndDeallocSimple(
-        rewriter, op, reducedMemRefType, loc, shapeHelper.dimsForOutput());
+        rewriter, op, reducedMemRefType, loc, outputDims);

// Constant Value
MathBuilder createMath(rewriter, loc);
@@ -81,13 +142,12 @@
// 1. Krnl loops to initialize the result.
ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
-    createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs,
-        shapeHelper.dimsForOutput(0),
+    createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
createKrnl.store(minusOne, alloc, loopInd);
});

-    // 2. Krnl loop to calculate argmax.
+    // 2. Krnl loop to calculate argmin/argmax.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
ValueRange calcLoopDef = createKrnl.defineLoops(dataRank);
SmallVector<IndexExpr, 4> lbs(dataRank, LiteralIndexExpr(0));
@@ -97,7 +157,7 @@
createKrnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Handle the operation:
-          SmallVector<Value, 4> inLoopIVs, outLoopIVs, maxLoopIVs;
+          SmallVector<Value, 4> inLoopIVs, outLoopIVs, dstLoopIVs;

for (int i = 0; i < dataRank; ++i)
inLoopIVs.push_back(loopInd[i]);
@@ -116,21 +176,21 @@
Value lessThanZero = create.math.slt(idx, zero);
idx = create.math.select(lessThanZero, zero, idx);

-          // induction variables of current max value
+          // induction variables of current min/max value
for (int i = 0; i < dataRank; ++i) {
if (i != axis)
-              maxLoopIVs.push_back(loopInd[i]);
+              dstLoopIVs.push_back(loopInd[i]);
else
-              maxLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
+              dstLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), idx));
}
-          Value maxVal = createKrnl.load(data, maxLoopIVs);
+          Value dstVal = createKrnl.load(data, dstLoopIVs);

-          // if next value is larger than current max value, update index
-          Value greaterThanMax = create.math.sgt(next, maxVal);
+          // if the next value is a new min/max, update the index
+          Value newDstVal = isNewVal<ArgOp>(create, next, dstVal);
Value pos = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(64), inLoopIVs[axis]);
-          idx = create.math.select(greaterThanMax, pos, idx);
+          idx = create.math.select(newDstVal, pos, idx);
createKrnl.store(idx, alloc, outLoopIVs);
});

Expand All @@ -139,9 +199,12 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
}
};

-void populateLoweringONNXArgMaxOpPattern(RewritePatternSet &patterns,
+void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
-  patterns.insert<ONNXArgMaxOpLowering>(typeConverter, ctx);
+  patterns.insert<ONNXArgMinMaxOpLowering<mlir::ONNXArgMinOp>>(
+      typeConverter, ctx);
+  patterns.insert<ONNXArgMinMaxOpLowering<mlir::ONNXArgMaxOp>>(
+      typeConverter, ctx);
}

} // namespace onnx_mlir
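An editor's note on the design: the patch keeps one templated pattern body and confines the min/max difference to the `isNewVal` specializations (`slt` for ArgMin, `sgt` for ArgMax). Below is a standalone C++ sketch of the same idiom, including the Krnl loop's init-to-minus-one and clamp-before-load logic, with no MLIR dependencies; the tag types and function names are illustrative stand-ins, not onnx-mlir APIs.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-ins for ONNXArgMinOp / ONNXArgMaxOp.
struct ArgMinTag {};
struct ArgMaxTag {};

// Primary template, specialized per op: decides whether `next` beats the
// current best value, mirroring isNewVal<ONNXArgMinOp> (slt) and
// isNewVal<ONNXArgMaxOp> (sgt).
template <typename ArgOp>
bool isNewVal(float next, float best);

template <>
bool isNewVal<ArgMinTag>(float next, float best) { return next < best; }

template <>
bool isNewVal<ArgMaxTag>(float next, float best) { return next > best; }

// One shared body for both ops, like ONNXArgMinMaxOpLowering<ArgOp>:
// the index slot starts at -1 (the init loop's store of minusOne) and is
// clamped to 0 before loading the current best, as in the Krnl loop.
template <typename ArgOp>
int64_t argMinMax1D(const std::vector<float> &data) {
  int64_t idx = -1;
  for (int64_t i = 0; i < (int64_t)data.size(); ++i) {
    int64_t cur = idx < 0 ? 0 : idx; // clamp: slot not yet written
    idx = isNewVal<ArgOp>(data[i], data[cur]) ? i : cur;
  }
  return idx;
}

int main() {
  std::vector<float> d = {3.0f, 1.0f, 2.0f, 1.0f};
  std::printf("argmin=%lld argmax=%lld\n",
      (long long)argMinMax1D<ArgMinTag>(d),
      (long long)argMinMax1D<ArgMaxTag>(d)); // argmin=1 argmax=0
  return 0;
}
```

Because both comparisons are strict, ties keep the earliest index, which matches ONNX's default select_last_index = 0 for both ops.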
8 changes: 7 additions & 1 deletion test/backend/inference_backend.py
@@ -114,7 +114,13 @@ def get_test_models():
"test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ArgMin

"test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_no_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== Asin
"test_asin_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_asin_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
16 changes: 16 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -59,6 +59,22 @@ func.func @test_default_argmax(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {

// -----

+//===----------------------------------------------------------------------===//
+/// Test the default behavior of argmin when no axis or keepdims
+/// attribute is provided.
+//===----------------------------------------------------------------------===//
+
+func.func @test_default_argmin(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {
+  %0 = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<*xi64>
+  "func.return"(%0) : (tensor<*xi64>) -> ()
+
+  // CHECK-LABEL: test_default_argmin
+  // CHECK: [[RES:%.+]] = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<1x3x4xi64>
+  // CHECK: return [[RES]] : tensor<1x3x4xi64>
+}

// -----

//===----------------------------------------------------------------------===//
/// Test the default behavior of transpose when no information for the
/// permutation of the axes is provided and when a permutation is provided.
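The CHECK lines above pin down the inferred type: with the default axis = 0 and keepdims = 1, tensor<2x3x4xf32> infers to tensor<1x3x4xi64>. A small C++ sketch of that shape rule, a reimplementation for illustration rather than the ONNXArgMinOpShapeHelper code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Output shape of ArgMin/ArgMax: normalize a possibly negative axis,
// then collapse that dimension to 1 (keepdims) or drop it entirely.
std::vector<int64_t> argMinMaxOutputShape(
    const std::vector<int64_t> &inShape, int64_t axis, bool keepdims) {
  int64_t rank = (int64_t)inShape.size();
  assert(axis >= -rank && axis < rank);
  if (axis < 0)
    axis += rank;
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (i != axis)
      out.push_back(inShape[i]);
    else if (keepdims)
      out.push_back(1);
  }
  return out;
}

// argMinMaxOutputShape({2, 3, 4}, /*axis=*/0, /*keepdims=*/true)
// yields {1, 3, 4}, matching tensor<1x3x4xi64> in the CHECK lines.
```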
