Add support for ArgMin #1737

Merged: 4 commits, Sep 30, 2022
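For context: ONNX ArgMin returns the index of the minimum element along one axis, as int64, mirroring the ArgMax support this PR generalizes. A minimal sketch of those semantics in plain C++ (illustrative only, not code from this PR; the input values are made up):

```cpp
// Illustrative only, not code from this PR: ONNX ArgMin semantics on a
// 2x3 matrix with axis = 0 and keepdims = 1 (result shape 1x3).
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  std::array<std::array<float, 3>, 2> data = {
      {{2.0f, 1.0f, 3.0f}, {1.0f, 4.0f, 0.5f}}};
  std::array<int64_t, 3> argmin{}; // index of the minimum in each column
  for (int col = 0; col < 3; ++col) {
    int64_t best = 0;
    for (int row = 1; row < 2; ++row)
      if (data[row][col] < data[best][col]) // the lowering's slt comparison
        best = row;
    argmin[col] = best;
  }
  for (int64_t i : argmin)
    std::cout << i << ' '; // prints: 1 0 1
  std::cout << '\n';
}
```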
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
@@ -19,7 +19,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitations
| **Add** |14 |No support for short integers. | |
| **And** |7 | | |
| **ArgMax** |13 | | |
| **ArgMin** | |unsupported | |
| **ArgMin** |13 | | |
| **ArrayFeatureExtractor** | |unsupported | |
| **Asin** |7 | | |
| **Asinh** |9 | | |
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -36,7 +36,7 @@ add_onnx_mlir_library(OMONNXToKrnl
Sequence/SequenceInsert.cpp
Sequence/SequenceLength.cpp
ConvertONNXToKrnl.cpp
Tensor/ArgMax.cpp
Tensor/ArgMinMax.cpp
Tensor/Compress.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -200,7 +200,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
// ObjectDetection
populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx);
// Tensor
populateLoweringONNXArgMaxOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXReshapeOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXPadOpPattern(patterns, typeConverter, ctx);
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -311,7 +311,7 @@ void populateLoweringONNXSequenceLengthOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);

// `Tensor` directory methods:
void populateLoweringONNXArgMaxOpPattern(
void populateLoweringONNXArgMinMaxOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
void populateLoweringONNXDimOpPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
src/Conversion/ONNXToKrnl/Tensor/{ArgMax.cpp → ArgMinMax.cpp}
@@ -2,13 +2,13 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ArgMax.cpp - Lowering ArgMax Op -------------------===//
//===------------ ArgMinMax.cpp - Lowering ArgMin/ArgMax Op ---------------===//
//
// Copyright 2021-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX ArgMax Operator to Krnl dialect.
// This file lowers the ONNX ArgMin/ArgMax Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

@@ -19,26 +19,44 @@
using namespace mlir;

namespace onnx_mlir {
struct ONNXArgMaxOpLowering : public ConversionPattern {
ONNXArgMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(
typeConverter, mlir::ONNXArgMaxOp::getOperationName(), 1, ctx) {}

template <typename ArgOp>
inline Value getCondition(MultiDialectBuilder<KrnlBuilder, MathBuilder> create,
Value next, Value dstVal);

template <>
inline Value getCondition<ONNXArgMinOp>(
MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
Value dstVal) {
return create.math.slt(next, dstVal);
}

template <>
inline Value getCondition<ONNXArgMaxOp>(
MultiDialectBuilder<KrnlBuilder, MathBuilder> create, Value next,
Value dstVal) {
return create.math.sgt(next, dstVal);
}

template <typename ArgOp, typename OpShapeHelper>
struct ONNXArgMinMaxOpLowering : public ConversionPattern {
ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
auto loc = op->getLoc();
ONNXArgMaxOpAdaptor operandAdaptor(operands);
ONNXArgMaxOp argMaxOp = llvm::cast<ONNXArgMaxOp>(op);
IndexExprScope scope(&rewriter, loc);
ArgOp argOp = llvm::cast<ArgOp>(op);

// shape helper
ONNXArgMaxOpShapeHelper shapeHelper(&argMaxOp, &rewriter,
typename ArgOp::Adaptor operandAdaptor(operands);
OpShapeHelper shapeHelper(&argOp, &rewriter,
krnl::getDenseElementAttributeFromKrnlValue,
krnl::loadDenseElementArrayValueAtIndex);

auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
(void)shapecomputed;
assert(!failed(shapecomputed) && "expected to succeed");
assert(succeeded(shapecomputed) && "Could not compute output shape");
DimsExpr outputDims = shapeHelper.dimsForOutput();

// Convert the reduced output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
@@ -49,16 +67,16 @@ struct ONNXArgMaxOpLowering : public ConversionPattern {
int64_t reducedRank = reducedMemRefType.getRank();

// data input
auto data = operandAdaptor.data();
auto dataType = data.getType().cast<MemRefType>();
Value data = operandAdaptor.data();
MemRefType dataType = data.getType().cast<MemRefType>();
int64_t dataRank = dataType.getRank();

// axis & keepdims attribute
int64_t axis = argMaxOp.axis();
int64_t axis = argOp.axis();
assert(axis >= -dataRank && axis <= dataRank - 1);
axis = axis >= 0 ? axis : (dataRank + axis);

int64_t keepdims = argMaxOp.keepdims();
int64_t keepdims = argOp.keepdims();
bool isKeepdims = (keepdims == 1) ? true : false;

// Get type information
@@ -69,7 +87,7 @@

// Insert alloc and dealloc
Value alloc = insertAllocAndDeallocSimple(
rewriter, op, reducedMemRefType, loc, shapeHelper.dimsForOutput());
rewriter, op, reducedMemRefType, loc, outputDims);

// Constant Value
MathBuilder createMath(rewriter, loc);
@@ -81,14 +99,13 @@
// 1. Krnl loops to initialize the result.
ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs,
shapeHelper.dimsForOutput(0),
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
createKrnl.store(minusOne, alloc, loopInd);
});

// 2. Krnl loop to calculate argmax.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
// 2. Krnl loop to calculate argmin/argmax.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
ValueRange calcLoopDef = createKrnl.defineLoops(dataRank);
SmallVector<IndexExpr, 4> lbs(dataRank, LiteralIndexExpr(0));
MemRefBoundsIndexCapture dataBounds(data);
@@ -97,7 +114,7 @@
createKrnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Handle the operation:
SmallVector<Value, 4> inLoopIVs, outLoopIVs, maxLoopIVs;
SmallVector<Value, 4> inLoopIVs, outLoopIVs, dstLoopIVs;

for (int i = 0; i < dataRank; ++i)
inLoopIVs.push_back(loopInd[i]);
@@ -116,21 +133,21 @@
Value lessThanZero = create.math.slt(idx, zero);
idx = create.math.select(lessThanZero, zero, idx);

// induction variables of current max value
// induction variables of current min/max value
for (int i = 0; i < dataRank; ++i) {
if (i != axis)
maxLoopIVs.push_back(loopInd[i]);
dstLoopIVs.push_back(loopInd[i]);
else
maxLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
dstLoopIVs.push_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), idx));
}
Value maxVal = createKrnl.load(data, maxLoopIVs);
Value dstVal = createKrnl.load(data, dstLoopIVs);

// if next value is larger than current max value, update index
Value greaterThanMax = create.math.sgt(next, maxVal);
// if next value is smaller/larger than current value, update index
Value newDstVal = getCondition<ArgOp>(create, next, dstVal);
Value pos = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(64), inLoopIVs[axis]);
idx = create.math.select(greaterThanMax, pos, idx);
idx = create.math.select(newDstVal, pos, idx);
createKrnl.store(idx, alloc, outLoopIVs);
});

@@ -139,9 +156,14 @@
}
};

void populateLoweringONNXArgMaxOpPattern(RewritePatternSet &patterns,
void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
patterns.insert<ONNXArgMaxOpLowering>(typeConverter, ctx);
patterns.insert<
ONNXArgMinMaxOpLowering<mlir::ONNXArgMinOp, ONNXArgMinOpShapeHelper>>(
typeConverter, ctx);
patterns.insert<
ONNXArgMinMaxOpLowering<mlir::ONNXArgMaxOp, ONNXArgMaxOpShapeHelper>>(
typeConverter, ctx);
}

} // namespace onnx_mlir
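The heart of the change above is a single lowering templated over the op and its shape helper, with only the comparison selected per op through the specialized getCondition helper. The following standalone C++ sketch reproduces that dispatch technique with stand-in types; apart from the getCondition name, everything here is hypothetical and not the MLIR code:

```cpp
// Standalone sketch of the dispatch pattern above: one templated lowering,
// one specialized comparison per op. ArgMinOp/ArgMaxOp are stand-in types,
// not the MLIR ops.
#include <cstdint>
#include <iostream>
#include <vector>

struct ArgMinOp {};
struct ArgMaxOp {};

// The primary template is declared but never defined, so only the two
// specializations below can be instantiated, matching the shape of the
// getCondition<ArgOp> helper in ArgMinMax.cpp.
template <typename ArgOp>
bool getCondition(float next, float current);

template <>
bool getCondition<ArgMinOp>(float next, float current) {
  return next < current; // mirrors create.math.slt(next, dstVal)
}

template <>
bool getCondition<ArgMaxOp>(float next, float current) {
  return next > current; // mirrors create.math.sgt(next, dstVal)
}

// Shared reduction: identical control flow for both ops, with only the
// comparison swapped in, as in ONNXArgMinMaxOpLowering.
template <typename ArgOp>
int64_t argReduce(const std::vector<float> &data) {
  int64_t idx = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(data.size()); ++i)
    if (getCondition<ArgOp>(data[i], data[idx]))
      idx = i;
  return idx;
}

int main() {
  std::vector<float> v = {3.0f, -1.0f, 7.0f, 2.0f};
  std::cout << argReduce<ArgMinOp>(v) << '\n'; // 1 (value -1.0)
  std::cout << argReduce<ArgMaxOp>(v) << '\n'; // 2 (value 7.0)
}
```

This keeps one copy of the Krnl loop nest for both ops, so a bug fix in the shared lowering benefits ArgMin and ArgMax alike.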
10 changes: 8 additions & 2 deletions test/backend/inference_backend.py
@@ -113,8 +113,14 @@ def get_test_models():
"test_argmax_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ArgMin

# ==OP== ArgMin
"test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_no_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== Asin
"test_asin_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_asin_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
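The shape-inference test added in the next file exercises ArgMin's ONNX defaults (axis = 0, keepdims = 1), under which tensor<2x3x4xf32> infers to tensor<1x3x4xi64>. A sketch of that output-shape rule under those assumptions (plain C++, not the actual ONNXArgMinOpShapeHelper implementation):

```cpp
// Sketch of the ArgMin/ArgMax output-shape rule, not the actual
// ONNXArgMinOpShapeHelper code: the reduced axis is kept as size 1 when
// keepdims is set, and dropped otherwise.
#include <cstdint>
#include <vector>

std::vector<int64_t> argReducedShape(
    std::vector<int64_t> dims, int64_t axis, bool keepdims) {
  int64_t rank = static_cast<int64_t>(dims.size());
  if (axis < 0)
    axis += rank; // negative axes count from the back, as in the lowering
  if (keepdims)
    dims[axis] = 1; // {2, 3, 4}, axis 0 -> {1, 3, 4}
  else
    dims.erase(dims.begin() + axis); // {2, 3, 4}, axis 0 -> {3, 4}
  return dims;
}
```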
16 changes: 16 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -59,6 +59,22 @@ func.func @test_default_argmax(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {

// -----

//===----------------------------------------------------------------------===//
/// Test the default behavior of argmin when no axis or keepdims
/// information is provided.
//===----------------------------------------------------------------------===//

func.func @test_default_argmin(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi64> {
%0 = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<*xi64>
"func.return"(%0) : (tensor<*xi64>) -> ()

// CHECK-LABEL: test_default_argmin
// CHECK: [[RES:%.+]] = "onnx.ArgMin"(%arg0) : (tensor<2x3x4xf32>) -> tensor<1x3x4xi64>
// CHECK: return [[RES]] : tensor<1x3x4xi64>
}

// -----

//===----------------------------------------------------------------------===//
/// Test the default behavior of transpose when no information for the
/// permutation of the axes is provided and when a permutation is provided.