Skip to content

Commit

Permalink
Fix dynamic shape issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Nov 22, 2024
1 parent 23869b4 commit dab54ab
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 60 deletions.
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3728,7 +3728,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{1},
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
Expand Down
102 changes: 47 additions & 55 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10677,7 +10677,6 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
Value boxes = op.getDets();
Value scores = op.getScores();
Value iouThreshold = op.getIouThreshold();
Type resultType = op.getType();

Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Expand All @@ -10693,22 +10692,11 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));

// Sort scores in descending order
// Use the sorted indices to iterate boxes
auto scoresType = dyn_cast<BaseTensorType>(scores.getType());
auto sortIndicesType = scoresType.getWithSizesAndDtype(
scoresType.getOptionalSizes(),
IntegerType::get(context, 64, IntegerType::Signed));
auto sortResult = rewriter.create<Torch::AtenSortOp>(
loc, TypeRange({scores.getType(), sortIndicesType}), scores,
/*dim=*/cst0, /*descending=*/cstTrue);

// Get number of boxes for the loop count
auto boxesTensorType = dyn_cast<Torch::ValueTensorType>(boxes.getType());
auto dType = boxesTensorType.getDtype();
int64_t boxesSize = boxesTensorType.getSizes()[0];
Value len = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(boxesSize));
Value len = rewriter.create<AtenSizeIntOp>(loc, boxes, /*dim=*/cst0);

// Calculate the area of each box: (x2 - x1) * (y2 - y1)
auto sliceTy = rewriter.getType<ValueTensorType>(
Expand All @@ -10727,15 +10715,22 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse,
/*dtype=*/cstNone);

// Sort scores in descending order
// Use the sorted indices to iterate boxes
auto scoresType = dyn_cast<BaseTensorType>(scores.getType());
auto intTensorType = scoresType.getWithSizesAndDtype(
scoresType.getOptionalSizes(),
IntegerType::get(context, 64, IntegerType::Signed));
auto sortResult = rewriter.create<Torch::AtenSortOp>(
loc, TypeRange({scores.getType(), intTensorType}), scores,
/*dim=*/cst0, /*descending=*/cstTrue);

// Create a mask to mark if we keep the boxes
Value maskShapeList = rewriter.create<Torch::PrimListConstructOp>(
Value lenShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
SmallVector<Value>{len});
auto maskTy = rewriter.getType<ValueTensorType>(
SmallVector<int64_t>{boxesSize},
rewriter.getIntegerType(64, /*signed=*/true));
Value mask = rewriter.create<Torch::AtenOnesOp>(
loc, maskTy, maskShapeList, cstNone, cstNone, cstNone, cstNone);
loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone);
Value zeroShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
SmallVector<Value>{cst1});
Expand All @@ -10744,15 +10739,10 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
Value falseMask = rewriter.create<Torch::AtenZerosOp>(
loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone);

// Create a zero tensor for output
auto resultTensorType = dyn_cast<Torch::ValueTensorType>(resultType);
Value resultLen = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(resultTensorType.getSizes()[0]));
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
SmallVector<Value>{resultLen});
Value output = rewriter.create<Torch::AtenZerosOp>(
loc, resultType, resultShapeList, cstNone, cstNone, cstNone, cstNone);
// Create an empty tensor for result
Value result = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone);

auto intTy = rewriter.getType<Torch::IntType>();
auto rowSliceTy =
Expand All @@ -10761,8 +10751,6 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
rewriter.getType<ValueTensorType>(SmallVector<int64_t>{1, 2}, dType);
auto extractTy = rewriter.getType<ValueTensorType>(
SmallVector<int64_t>{1}, rewriter.getIntegerType(64, true));
Value cnt = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value float0 = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getFloatAttr(dType, 0.0));
auto scalarFloatType = rewriter.getType<Torch::ValueTensorType>(
Expand All @@ -10776,17 +10764,18 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
// 4. Loop through the remaining boxes in sorted-index order
// 5. Suppress the box if the corresponding IoU is larger than threshold
auto loop1 = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({maskTy, resultType, intTy}), len, cstTrue,
ValueRange({mask, output, cnt}));
loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue,
ValueRange({mask, result, cst0}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *rowLoopBody = rewriter.createBlock(
Block *loopBody1 = rewriter.createBlock(
&loop1.getRegion(), loop1.getRegion().begin(),
TypeRange({intTy, maskTy, resultType, intTy}), {loc, loc, loc, loc});
Value i = rowLoopBody->getArgument(0);
Value mask1 = rowLoopBody->getArgument(1);
Value curOutput = rowLoopBody->getArgument(2);
Value curCnt = rowLoopBody->getArgument(3);
TypeRange({intTy, intTensorType, intTensorType, intTy}),
{loc, loc, loc, loc});
Value i = loopBody1->getArgument(0);
Value mask1 = loopBody1->getArgument(1);
Value curResult = loopBody1->getArgument(2);
Value curCnt = loopBody1->getArgument(3);

// Extract the mask to check if the base box is suppressed
Value extract = rewriter.create<AtenSelectIntOp>(
Expand All @@ -10795,20 +10784,20 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
Value iskept = rewriter.create<Torch::AtenBoolIntOp>(
loc, rewriter.getType<Torch::BoolType>(), scalar);
auto ifFilterOthers = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({maskTy, resultType, intTy}), iskept);
loc, TypeRange({intTensorType, intTensorType, intTy}), iskept);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifFilterOthers.getThenRegion(),
ifFilterOthers.getThenRegion().begin());

// Fill the selected indices into output
// Scatter the selected indices into result
Value extractIdx1 = rewriter.create<AtenSelectIntOp>(
loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0,
/*index=*/i);
Value nextCnt = rewriter.create<Torch::AtenAddIntOp>(loc, curCnt, cst1);
Value updatedOutput = rewriter.create<Torch::AtenSliceScatterOp>(
loc, resultType, curOutput, extractIdx1, /*dim=*/cst0,
/*start=*/curCnt, /*end=*/nextCnt, /*step=*/cst1);
Value next = rewriter.create<Torch::AtenAddIntOp>(loc, curCnt, cst1);
Value updatedResult = rewriter.create<Torch::AtenSliceScatterOp>(
loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0,
/*start=*/curCnt, /*end=*/next, /*step=*/cst1);

// Get the coordinates of base box
Value idx1 =
Expand All @@ -10818,6 +10807,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
loc, rowSliceTy, boxes,
/*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1);

// Calculate IoUs: intersectionArea / unionArea
// Intersection area = intersectionWidth * intersectionHeight
Value point1 = rewriter.create<AtenSliceTensorOp>(
loc, pointTy, slice1,
Expand Down Expand Up @@ -10849,22 +10839,22 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
loc, areaTy, intersectionArea, unionArea);

// Loop through the rest of the boxes in sorted-index order
auto loop2 = rewriter.create<Torch::PrimLoopOp>(loc, maskTy, len,
cstTrue, mask1);
auto loop2 = rewriter.create<Torch::PrimLoopOp>(loc, intTensorType,
len, cstTrue, mask1);
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *colLoopBody = rewriter.createBlock(
Block *loopBody2 = rewriter.createBlock(
&loop2.getRegion(), loop2.getRegion().begin(),
TypeRange({intTy, maskTy}), {loc, loc});
Value j = colLoopBody->getArgument(0);
Value mask2 = colLoopBody->getArgument(1);
TypeRange({intTy, intTensorType}), {loc, loc});
Value j = loopBody2->getArgument(0);
Value mask2 = loopBody2->getArgument(1);

// Offset j to i + j + 1 and check that the resulting index is still in range
j = rewriter.create<Torch::AtenAddIntOp>(loc, j, i);
j = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
Value isInRange = rewriter.create<Torch::AtenLtIntOp>(loc, j, len);
auto ifCalculateIou = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({maskTy}), isInRange);
loc, TypeRange({intTensorType}), isInRange);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifCalculateIou.getThenRegion(),
Expand All @@ -10887,7 +10877,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
loc, curIoU, iouThreshold);

auto ifUnmask = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({maskTy}), isSuppressed);
loc, TypeRange({intTensorType}), isSuppressed);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifUnmask.getThenRegion(),
Expand All @@ -10896,7 +10886,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
// Update the mask if the box is suppressed
Value jEnd = rewriter.create<Torch::AtenAddIntOp>(loc, j, cst1);
Value updatedMask = rewriter.create<Torch::AtenSliceScatterOp>(
loc, maskTy, mask2, falseMask, /*dim=*/cst0,
loc, intTensorType, mask2, falseMask, /*dim=*/cst0,
/*start=*/j, /*end=*/jEnd, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, updatedMask);
}
Expand All @@ -10921,21 +10911,23 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
}

rewriter.create<Torch::PrimIfYieldOp>(
loc, ValueRange({loop2.getResult(0), updatedOutput, nextCnt}));
loc, ValueRange({loop2.getResult(0), updatedResult, next}));
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifFilterOthers.getElseRegion(),
ifFilterOthers.getElseRegion().begin());
rewriter.create<Torch::PrimIfYieldOp>(
loc, ValueRange({mask1, curOutput, curCnt}));
loc, ValueRange({mask1, curResult, curCnt}));
}

rewriter.create<Torch::PrimLoopConditionOp>(loc, cstTrue,
ifFilterOthers.getResults());
}

rewriter.replaceOp(op, loop1.getResult(1));
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0,
/*end=*/loop1.getResult(2), /*step=*/cst1);
return success();
}
};
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2054,8 +2054,8 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[1],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
Expand Down Expand Up @@ -2106,8 +2106,8 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
Expand Down

0 comments on commit dab54ab

Please sign in to comment.