diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp index 69fc7136fd..79e9300bf4 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp @@ -38,36 +38,7 @@ inline Value getCondition( return create.math.sgt(next, dstVal); } -template -inline DimsExpr getOutputDims(ArgOp *op, typename ArgOp::Adaptor operandAdaptor, - mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, - ArrayValueIndexCapture::LoadVal fLoadVal); - -template <> -inline DimsExpr getOutputDims(ONNXArgMinOp *op, - typename ONNXArgMinOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, - ArrayValueIndexCapture::GetDenseVal fGetDenseVal, - ArrayValueIndexCapture::LoadVal fLoadVal) { - ONNXArgMinOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal); - auto shapecomputed = shapeHelper.computeShape(operandAdaptor); - (void)shapecomputed; - assert(!failed(shapecomputed) && "expected to succeed"); - return shapeHelper.dimsForOutput(); -} - -template <> -inline DimsExpr getOutputDims(ONNXArgMaxOp *op, - typename ONNXArgMaxOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter, - ArrayValueIndexCapture::GetDenseVal fGetDenseVal, - ArrayValueIndexCapture::LoadVal fLoadVal) { - ONNXArgMaxOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal); - auto shapecomputed = shapeHelper.computeShape(operandAdaptor); - (void)shapecomputed; - assert(!failed(shapecomputed) && "expected to succeed"); - return shapeHelper.dimsForOutput(); -} - -template +template struct ONNXArgMinMaxOpLowering : public ConversionPattern { ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) : ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {} @@ -80,9 +51,12 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { ArgOp argOp = llvm::cast(op); typename ArgOp::Adaptor operandAdaptor(operands); - DimsExpr OutputDims = getOutputDims(&argOp, operandAdaptor, - &rewriter, krnl::getDenseElementAttributeFromKrnlValue, + OpShapeHelper shapeHelper(&argOp, &rewriter, + krnl::getDenseElementAttributeFromKrnlValue, krnl::loadDenseElementArrayValueAtIndex); + auto shapecomputed = shapeHelper.computeShape(operandAdaptor); + assert(succeeded(shapecomputed) && "Could not compute output shape"); + DimsExpr outputDims = shapeHelper.dimsForOutput(); // Convert the reduced output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); @@ -113,7 +87,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { // Insert alloc and dealloc Value alloc = insertAllocAndDeallocSimple( - rewriter, op, reducedMemRefType, loc, OutputDims); + rewriter, op, reducedMemRefType, loc, outputDims); // Constant Value MathBuilder createMath(rewriter, loc); @@ -125,13 +99,13 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { // 1. Krnl loops to initialize the result. ValueRange initLoopDef = createKrnl.defineLoops(reducedRank); SmallVector initLbs(reducedRank, LiteralIndexExpr(0)); - createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims, + createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(minusOne, alloc, loopInd); }); // 2. Krnl loop to calculate argmin/argmax. - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create(createKrnl); ValueRange calcLoopDef = createKrnl.defineLoops(dataRank); SmallVector lbs(dataRank, LiteralIndexExpr(0)); MemRefBoundsIndexCapture dataBounds(data); @@ -184,9 +158,11 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern { void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { - patterns.insert>( + patterns.insert< + ONNXArgMinMaxOpLowering>( typeConverter, ctx); - patterns.insert>( + patterns.insert< + ONNXArgMinMaxOpLowering>( typeConverter, ctx); }