diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp index 20c39fb0763..68d59d73f7b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp @@ -39,13 +39,13 @@ inline Value getCondition( } template -inline llvm::SmallVector getOutputDims(ArgOp *op, +inline DimsExpr getOutputDims(ArgOp *op, typename ArgOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, ArrayValueIndexCapture::LoadVal fLoadVal); template <> -inline llvm::SmallVector getOutputDims( +inline DimsExpr getOutputDims( ONNXArgMinOp *op, typename ONNXArgMinOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, ArrayValueIndexCapture::LoadVal fLoadVal) { @@ -57,7 +57,7 @@ inline llvm::SmallVector getOutputDims( } template <> -inline llvm::SmallVector getOutputDims( +inline DimsExpr getOutputDims( ONNXArgMaxOp *op, typename ONNXArgMaxOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, ArrayValueIndexCapture::LoadVal fLoadVal) { @@ -81,7 +81,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { ArgOp argOp = llvm::cast(op); typename ArgOp::Adaptor operandAdaptor(operands); - auto OutputDims = getOutputDims(&argOp, operandAdaptor, &rewriter, + DimsExpr OutputDims = getOutputDims(&argOp, operandAdaptor, &rewriter, krnl::getDenseElementAttributeFromKrnlValue, krnl::loadDenseElementArrayValueAtIndex); @@ -126,7 +126,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { // 1. Krnl loops to initialize the result. ValueRange initLoopDef = createKrnl.defineLoops(reducedRank); SmallVector initLbs(reducedRank, LiteralIndexExpr(0)); - createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims[0], + createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(minusOne, alloc, loopInd); }); diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 402cc88ef19..099bd488a18 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -113,7 +113,7 @@ def get_test_models(): "test_argmax_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - # ArgMin + # ==OP== ArgMin "test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},