Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support NMS op lowering #3871

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
46 changes: 27 additions & 19 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3687,9 +3687,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"attribute value to be 0");

// TODO: Add support for optional arguments to be absent.
if (operands.size() != 5)
if (operands.size() < 4)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected all 5 args to be present");
binder.op, "unimplemented: expected at least 4 arguments");

// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
Expand Down Expand Up @@ -3721,27 +3721,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// If score_threshold > min(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
}

Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), resultType, boxes, scores, iouThreshold);
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
Expand Down
270 changes: 270 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10666,6 +10666,273 @@ class DecomposeAtenFloatPowerTensorTensorOp
};
} // namespace

namespace {
// Decompose `torchvision::nms` into primitive Torch dialect ops.
//
// Implements the classic O(n^2) non-maximum-suppression algorithm:
//   1. Precompute the area of every box: (x2 - x1) * (y2 - y1).
//   2. Sort the scores in descending order and walk the boxes in that order.
//   3. For each box not yet suppressed, append its ORIGINAL index to the
//      result, then suppress every lower-scored box whose IoU with it
//      exceeds `iou_threshold`.
//   4. Slice the result down to the number of boxes actually kept.
//
// Assumes boxes are in [x1, y1, x2, y2] corner format (the torchvision
// convention). NOTE(review): the leading dimension of `boxes` is read from
// the static type — confirm callers guarantee a static size here.
class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TorchvisionNmsOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = op->getContext();
    Value boxes = op.getDets();
    Value scores = op.getScores();
    Value iouThreshold = op.getIouThreshold();

    // Commonly reused scalar constants.
    Value cst0 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value cst1 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value cst2 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(2));
    Value cst4 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(4));
    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value cstTrue =
        rewriter.create<Torch::ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));

    // Number of boxes; drives both loop trip counts.
    auto boxesTensorType = dyn_cast<Torch::ValueTensorType>(boxes.getType());
    auto dType = boxesTensorType.getDtype();
    int64_t boxesSize = boxesTensorType.getSizes()[0];
    Value len = rewriter.create<AtenSizeIntOp>(loc, boxes, /*dim=*/cst0);

    // Calculate the area of each box: (x2 - x1) * (y2 - y1).
    // area[k] is the area of boxes[k] in the ORIGINAL (unsorted) order.
    auto sliceTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{boxesSize, 2}, dType);
    Value lowSlice = rewriter.create<AtenSliceTensorOp>(
        loc, sliceTy, boxes,
        /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1);
    Value highSlice = rewriter.create<AtenSliceTensorOp>(
        loc, sliceTy, boxes,
        /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1);
    Value distance = rewriter.create<Torch::AtenSubTensorOp>(
        loc, sliceTy, highSlice, lowSlice, cst1);
    auto areaTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{boxesSize}, dType);
    Value area = rewriter.create<Torch::AtenProdDimIntOp>(
        loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);

    // Sort scores in descending order; the sorted indices define the visit
    // order over the boxes.
    auto scoresType = dyn_cast<BaseTensorType>(scores.getType());
    auto intTensorType = scoresType.getWithSizesAndDtype(
        scoresType.getOptionalSizes(),
        IntegerType::get(context, 64, IntegerType::Signed));
    auto sortResult = rewriter.create<Torch::AtenSortOp>(
        loc, TypeRange({scores.getType(), intTensorType}), scores,
        /*dim=*/cst0, /*descending=*/cstTrue);

    // mask[p] == 1 means the box at SORTED position p is still a candidate.
    Value lenShapeList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        SmallVector<Value>{len});
    Value mask = rewriter.create<Torch::AtenOnesOp>(
        loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone);
    // One-element zero tensor scattered into `mask` to suppress a box.
    Value zeroShapeList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        SmallVector<Value>{cst1});
    auto zeroTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{1}, rewriter.getIntegerType(64, /*signed=*/true));
    Value falseMask = rewriter.create<Torch::AtenZerosOp>(
        loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone);

    // Uninitialized result buffer; only the first `count` entries are
    // meaningful and are sliced out at the end. The dtype argument reuses
    // `cst4`, which is also torch's scalar-type code for int64.
    Value result = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
        loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone,
        /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone);

    auto intTy = rewriter.getType<Torch::IntType>();
    auto rowSliceTy =
        rewriter.getType<ValueTensorType>(SmallVector<int64_t>{1, 4}, dType);
    auto pointTy =
        rewriter.getType<ValueTensorType>(SmallVector<int64_t>{1, 2}, dType);
    auto extractTy = rewriter.getType<ValueTensorType>(
        SmallVector<int64_t>{1}, rewriter.getIntegerType(64, true));
    Value float0 = rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getFloatAttr(dType, 0.0));
    auto scalarFloatType = rewriter.getType<Torch::ValueTensorType>(
        SmallVector<int64_t>{1}, dType);
    Value float0Tensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
        loc, scalarFloatType, float0);

    // 1. Loop through the boxes based on sorted indices
    // 2. Add the current box to result if it's not suppressed
    // 3. Calculate the IoUs with all boxes
    // 4. Loop through the rest boxes in sorted indices
    // 5. Suppress the box if the corresponding IoU is larger than threshold
    //
    // Loop-carried values: (mask, result, count-of-kept-boxes).
    auto loop1 = rewriter.create<Torch::PrimLoopOp>(
        loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue,
        ValueRange({mask, result, cst0}));
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      Block *loopBody1 = rewriter.createBlock(
          &loop1.getRegion(), loop1.getRegion().begin(),
          TypeRange({intTy, intTensorType, intTensorType, intTy}),
          {loc, loc, loc, loc});
      Value i = loopBody1->getArgument(0); // position in sorted order
      Value mask1 = loopBody1->getArgument(1);
      Value curResult = loopBody1->getArgument(2);
      Value curCnt = loopBody1->getArgument(3);

      // Is the box at sorted position `i` still unsuppressed?
      Value extract = rewriter.create<AtenSelectIntOp>(
          loc, extractTy, mask1, /*dim=*/cst0, /*index=*/i);
      Value scalar = rewriter.create<Torch::AtenItemOp>(loc, intTy, extract);
      Value iskept = rewriter.create<Torch::AtenBoolIntOp>(
          loc, rewriter.getType<Torch::BoolType>(), scalar);
      auto ifFilterOthers = rewriter.create<Torch::PrimIfOp>(
          loc, TypeRange({intTensorType, intTensorType, intTy}), iskept);
      {
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.createBlock(&ifFilterOthers.getThenRegion(),
                             ifFilterOthers.getThenRegion().begin());

        // Scatter the selected ORIGINAL box index into result[curCnt].
        Value extractIdx1 = rewriter.create<AtenSelectIntOp>(
            loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0,
            /*index=*/i);
        Value next = rewriter.create<Torch::AtenAddIntOp>(loc, curCnt, cst1);
        Value updatedResult = rewriter.create<Torch::AtenSliceScatterOp>(
            loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0,
            /*start=*/curCnt, /*end=*/next, /*step=*/cst1);

        // Get the coordinates of the base box (original index `idx1`).
        Value idx1 =
            rewriter.create<Torch::AtenItemOp>(loc, intTy, extractIdx1);
        Value idx1End = rewriter.create<Torch::AtenAddIntOp>(loc, idx1, cst1);
        Value slice1 = rewriter.create<AtenSliceTensorOp>(
            loc, rowSliceTy, boxes,
            /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1);

        // Calculate IoUs against ALL boxes: intersectionArea / unionArea.
        // Intersection area = intersectionWidth * intersectionHeight.
        Value point1 = rewriter.create<AtenSliceTensorOp>(
            loc, pointTy, slice1,
            /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1);
        Value point2 = rewriter.create<AtenSliceTensorOp>(
            loc, pointTy, slice1,
            /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1);
        Value innerLow = rewriter.create<Torch::AtenMaximumOp>(
            loc, sliceTy, lowSlice, point1);
        Value innerHigh = rewriter.create<Torch::AtenMinimumOp>(
            loc, sliceTy, highSlice, point2);
        Value innerDistance = rewriter.create<Torch::AtenSubTensorOp>(
            loc, sliceTy, innerHigh, innerLow, cst1);
        // Clamp negative extents (disjoint boxes) to zero.
        innerDistance = rewriter.create<Torch::AtenMaximumOp>(
            loc, sliceTy, innerDistance, float0Tensor);
        Value intersectionArea = rewriter.create<Torch::AtenProdDimIntOp>(
            loc, areaTy, innerDistance, /*dim=*/cst1, /*keepdim=*/cstFalse,
            /*dtype=*/cstNone);
        // Union area = area + baseArea - intersectionArea.
        // BUGFIX: the base box lives at ORIGINAL index `idx1`, not at sorted
        // position `i` — `area` (like `intersectionArea` and `iou`) is
        // indexed in original box order, so its area is area[idx1].
        Value curArea = rewriter.create<AtenSliceTensorOp>(
            loc, scalarFloatType, area,
            /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1);
        Value unionArea = rewriter.create<Torch::AtenAddTensorOp>(
            loc, areaTy, area, curArea, cst1);
        unionArea = rewriter.create<Torch::AtenSubTensorOp>(
            loc, areaTy, unionArea, intersectionArea, cst1);
        // iou[k] = IoU(base box, boxes[k]) in original order.
        Value iou = rewriter.create<Torch::AtenDivTensorOp>(
            loc, areaTy, intersectionArea, unionArea);

        // Loop through the rest of the boxes in sorted order and suppress
        // those that overlap the base box too much.
        auto loop2 = rewriter.create<Torch::PrimLoopOp>(loc, intTensorType, len,
                                                        cstTrue, mask1);
        {
          PatternRewriter::InsertionGuard guard(rewriter);
          Block *loopBody2 = rewriter.createBlock(
              &loop2.getRegion(), loop2.getRegion().begin(),
              TypeRange({intTy, intTensorType}), {loc, loc});
          Value j = loopBody2->getArgument(0);
          Value mask2 = loopBody2->getArgument(1);

          // Shift j to the sorted position after i; the loop always runs
          // `len` iterations, so guard against running past the end.
          j = rewriter.create<Torch::AtenAddIntOp>(loc, j, i);
          j = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
          Value isInRange = rewriter.create<Torch::AtenLtIntOp>(loc, j, len);
          auto ifCalculateIou = rewriter.create<Torch::PrimIfOp>(
              loc, TypeRange({intTensorType}), isInRange);
          {
            PatternRewriter::InsertionGuard guard(rewriter);
            rewriter.createBlock(&ifCalculateIou.getThenRegion(),
                                 ifCalculateIou.getThenRegion().begin());

            // Retrieve the IoU for the box at sorted position `j` (original
            // index `idx2`) and decide whether to suppress it.
            Value extractIdx2 = rewriter.create<AtenSelectIntOp>(
                loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0,
                /*index=*/j);
            Value idx2 =
                rewriter.create<Torch::AtenItemOp>(loc, intTy, extractIdx2);
            Value idx2End =
                rewriter.create<Torch::AtenAddIntOp>(loc, idx2, cst1);
            Value curIoU = rewriter.create<AtenSliceTensorOp>(
                loc, scalarFloatType, iou,
                /*dim=*/cst0, /*start=*/idx2, /*end=*/idx2End, /*step=*/cst1);
            curIoU = rewriter.create<Torch::AtenItemOp>(
                loc, rewriter.getType<Torch::FloatType>(), curIoU);
            Value isSuppressed = rewriter.create<Torch::AtenGtFloatOp>(
                loc, curIoU, iouThreshold);

            auto ifUnmask = rewriter.create<Torch::PrimIfOp>(
                loc, TypeRange({intTensorType}), isSuppressed);
            {
              PatternRewriter::InsertionGuard guard(rewriter);
              rewriter.createBlock(&ifUnmask.getThenRegion(),
                                   ifUnmask.getThenRegion().begin());

              // Clear mask[j] (sorted position) to suppress the box.
              Value jEnd = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
              Value updatedMask = rewriter.create<Torch::AtenSliceScatterOp>(
                  loc, intTensorType, mask2, falseMask, /*dim=*/cst0,
                  /*start=*/j, /*end=*/jEnd, /*step=*/cst1);
              rewriter.create<Torch::PrimIfYieldOp>(loc, updatedMask);
            }
            {
              PatternRewriter::InsertionGuard guard(rewriter);
              rewriter.createBlock(&ifUnmask.getElseRegion(),
                                   ifUnmask.getElseRegion().begin());
              rewriter.create<Torch::PrimIfYieldOp>(loc, mask2);
            }

            rewriter.create<Torch::PrimIfYieldOp>(loc, ifUnmask.getResult(0));
          }
          {
            PatternRewriter::InsertionGuard guard(rewriter);
            rewriter.createBlock(&ifCalculateIou.getElseRegion(),
                                 ifCalculateIou.getElseRegion().begin());
            rewriter.create<Torch::PrimIfYieldOp>(loc, mask2);
          }

          rewriter.create<Torch::PrimLoopConditionOp>(
              loc, cstTrue, ifCalculateIou.getResult(0));
        }

        rewriter.create<Torch::PrimIfYieldOp>(
            loc, ValueRange({loop2.getResult(0), updatedResult, next}));
      }
      {
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.createBlock(&ifFilterOthers.getElseRegion(),
                             ifFilterOthers.getElseRegion().begin());
        rewriter.create<Torch::PrimIfYieldOp>(
            loc, ValueRange({mask1, curResult, curCnt}));
      }

      rewriter.create<Torch::PrimLoopConditionOp>(loc, cstTrue,
                                                  ifFilterOthers.getResults());
    }

    // Keep only the first `count` kept indices.
    rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
        op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0,
        /*end=*/loop1.getResult(2), /*step=*/cst1);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -10950,6 +11217,9 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<
DecomposeAtenFMaxMinOp<AtenFminOp, AtenMinimumOp>>(patterns);

// Torchvision ops
addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
Expand Down
Loading
Loading