Skip to content

Commit

Permalink
put shapeHelper into matchAndRewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Hengyu Meng <[email protected]>
  • Loading branch information
airMeng committed Sep 30, 2022
1 parent 3662804 commit da66844
Showing 1 changed file with 13 additions and 37 deletions.
50 changes: 13 additions & 37 deletions src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,7 @@ inline Value getCondition<ONNXArgMaxOp>(
return create.math.sgt(next, dstVal);
}

template <typename ArgOp>
inline DimsExpr getOutputDims(ArgOp *op, typename ArgOp::Adaptor operandAdaptor,
mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal);

template <>
inline DimsExpr getOutputDims<ONNXArgMinOp>(ONNXArgMinOp *op,
typename ONNXArgMinOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter,
ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal) {
ONNXArgMinOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal);
auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
(void)shapecomputed;
assert(!failed(shapecomputed) && "expected to succeed");
return shapeHelper.dimsForOutput();
}

template <>
inline DimsExpr getOutputDims<ONNXArgMaxOp>(ONNXArgMaxOp *op,
typename ONNXArgMaxOp::Adaptor operandAdaptor, mlir::OpBuilder *rewriter,
ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal) {
ONNXArgMaxOpShapeHelper shapeHelper(op, rewriter, fGetDenseVal, fLoadVal);
auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
(void)shapecomputed;
assert(!failed(shapecomputed) && "expected to succeed");
return shapeHelper.dimsForOutput();
}

template <typename ArgOp>
template <typename ArgOp, typename OpShapeHelper>
struct ONNXArgMinMaxOpLowering : public ConversionPattern {
ONNXArgMinMaxOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(typeConverter, ArgOp::getOperationName(), 1, ctx) {}
Expand All @@ -80,9 +51,12 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {
ArgOp argOp = llvm::cast<ArgOp>(op);

typename ArgOp::Adaptor operandAdaptor(operands);
DimsExpr OutputDims = getOutputDims<ArgOp>(&argOp, operandAdaptor,
&rewriter, krnl::getDenseElementAttributeFromKrnlValue,
OpShapeHelper shapeHelper(&argOp, &rewriter,
krnl::getDenseElementAttributeFromKrnlValue,
krnl::loadDenseElementArrayValueAtIndex);
auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
assert(succeeded(shapecomputed) && "Could not compute output shape");
DimsExpr outputDims = shapeHelper.dimsForOutput();

// Convert the reduced output type to MemRefType.
Type convertedType = typeConverter->convertType(*op->result_type_begin());
Expand Down Expand Up @@ -113,7 +87,7 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {

// Insert alloc and dealloc
Value alloc = insertAllocAndDeallocSimple(
rewriter, op, reducedMemRefType, loc, OutputDims);
rewriter, op, reducedMemRefType, loc, outputDims);

// Constant Value
MathBuilder createMath(rewriter, loc);
Expand All @@ -125,13 +99,13 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {
// 1. Krnl loops to initialize the result.
ValueRange initLoopDef = createKrnl.defineLoops(reducedRank);
SmallVector<IndexExpr, 4> initLbs(reducedRank, LiteralIndexExpr(0));
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, OutputDims,
createKrnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
createKrnl.store(minusOne, alloc, loopInd);
});

// 2. Krnl loop to calculate argmin/argmax.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
ValueRange calcLoopDef = createKrnl.defineLoops(dataRank);
SmallVector<IndexExpr, 4> lbs(dataRank, LiteralIndexExpr(0));
MemRefBoundsIndexCapture dataBounds(data);
Expand Down Expand Up @@ -184,9 +158,11 @@ struct ONNXArgMinMaxOpLowering : public ConversionPattern {

void populateLoweringONNXArgMinMaxOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
patterns.insert<ONNXArgMinMaxOpLowering<mlir::ONNXArgMinOp>>(
patterns.insert<
ONNXArgMinMaxOpLowering<mlir::ONNXArgMinOp, ONNXArgMinOpShapeHelper>>(
typeConverter, ctx);
patterns.insert<ONNXArgMinMaxOpLowering<mlir::ONNXArgMaxOp>>(
patterns.insert<
ONNXArgMinMaxOpLowering<mlir::ONNXArgMaxOp, ONNXArgMaxOpShapeHelper>>(
typeConverter, ctx);
}

Expand Down

0 comments on commit da66844

Please sign in to comment.