Skip to content

Commit

Permalink
using DimsExpr for outputDims
Browse files Browse the repository at this point in the history
Signed-off-by: Hengyu Meng <[email protected]>
  • Loading branch information
airMeng committed Sep 29, 2022
1 parent 46c6194 commit 8907255
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ inline Value getCondition<ONNXArgMaxOp>(
}

template <typename ArgOp>
inline llvm::SmallVector<IndexExpr, 4> getOutputDims(ArgOp *op,
inline DimsExpr getOutputDims(ArgOp *op,
typename ArgOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter,
ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal);

template <>
inline llvm::SmallVector<IndexExpr, 4> getOutputDims<ONNXArgMinOp>(
inline DimsExpr getOutputDims<ONNXArgMinOp>(
ONNXArgMinOp *op, typename ONNXArgMinOp::Adaptor operandAdaptor,
mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal) {
Expand All @@ -57,7 +57,7 @@ inline llvm::SmallVector<IndexExpr, 4> getOutputDims<ONNXArgMinOp>(
}

template <>
inline llvm::SmallVector<IndexExpr, 4> getOutputDims<ONNXArgMaxOp>(
inline DimsExpr getOutputDims<ONNXArgMaxOp>(
ONNXArgMaxOp *op, typename ONNXArgMaxOp::Adaptor operandAdaptor,
mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal) {
Expand All @@ -81,7 +81,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {
ArgOp argOp = llvm::cast<ArgOp>(op);

typename ArgOp::Adaptor operandAdaptor(operands);
auto OutputDims = getOutputDims<ArgOp>(&argOp, operandAdaptor, &rewriter,
DimsExpr OutputDims = getOutputDims<ArgOp>(&argOp, operandAdaptor, &rewriter,
krnl::getDenseElementAttributeFromKrnlValue,
krnl::loadDenseElementArrayValueAtIndex);

Expand Down Expand Up @@ -126,7 +126,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {
// 1. Krnl loops to initialize the result.
ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims[0],
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
createKrnl.store(minusOne, alloc, loopInd);
});
Expand Down
2 changes: 1 addition & 1 deletion test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_test_models():
"test_argmax_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmax_default_axis_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ArgMin
# ==OP== ArgMin
"test_argmin_no_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_argmin_default_axis_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
Expand Down

0 comments on commit 8907255

Please sign in to comment.