From dab54ab63dbc414bd2c11c6f7c890c5a45564be3 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Tue, 19 Nov 2024 16:55:42 -0800 Subject: [PATCH] Fix dynamic shape issue --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 102 ++++++++---------- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 8 +- 3 files changed, 52 insertions(+), 60 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index dc76c9689480..fccbbc2921f3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3728,7 +3728,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value minScores = rewriter.create( binder.getLoc(), Torch::ValueTensorType::get(binder.op->getContext(), - SmallVector{1}, + SmallVector{}, rewriter.getF32Type()), scores); minScores = rewriter.create( diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index eb835c2554f2..0c6801eed48a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -10677,7 +10677,6 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value boxes = op.getDets(); Value scores = op.getScores(); Value iouThreshold = op.getIouThreshold(); - Type resultType = op.getType(); Value cst0 = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); @@ -10693,22 +10692,11 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); - // Sort scores in descending order - // Use the sorted indices to iterate boxes - auto scoresType = dyn_cast(scores.getType()); - auto sortIndicesType = scoresType.getWithSizesAndDtype( - scoresType.getOptionalSizes(), - IntegerType::get(context, 64, IntegerType::Signed)); - auto sortResult = rewriter.create( - loc, TypeRange({scores.getType(), sortIndicesType}), scores, - /*dim=*/cst0, /*descending=*/cstTrue); - // Get number of boxes for the loop count auto boxesTensorType = dyn_cast(boxes.getType()); auto dType = boxesTensorType.getDtype(); int64_t boxesSize = boxesTensorType.getSizes()[0]; - Value len = rewriter.create( - loc, rewriter.getI64IntegerAttr(boxesSize)); + Value len = rewriter.create(loc, boxes, /*dim=*/cst0); // Calculate the area of each box: (x2 - x1) * (y2 - y1) auto sliceTy = rewriter.getType( @@ -10727,15 +10715,22 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse, /*dtype=*/cstNone); + // Sort scores in descending order + // Use the sorted indices to iterate boxes + auto scoresType = dyn_cast(scores.getType()); + auto intTensorType = scoresType.getWithSizesAndDtype( + scoresType.getOptionalSizes(), + IntegerType::get(context, 64, IntegerType::Signed)); + auto sortResult = rewriter.create( + loc, TypeRange({scores.getType(), intTensorType}), scores, + /*dim=*/cst0, /*descending=*/cstTrue); + // Create a mask to mark if we keep the boxes - Value maskShapeList = rewriter.create( + Value lenShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), SmallVector{len}); - auto maskTy = rewriter.getType( - SmallVector{boxesSize}, - rewriter.getIntegerType(64, /*signed=*/true)); Value mask = rewriter.create( - loc, maskTy, maskShapeList, cstNone, cstNone, cstNone, cstNone); + loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone); Value zeroShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), SmallVector{cst1}); @@ -10744,15 +10739,10 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value falseMask = rewriter.create( loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone); - // Create a zero tensor for output - auto resultTensorType = dyn_cast(resultType); - Value resultLen = rewriter.create( - loc, rewriter.getI64IntegerAttr(resultTensorType.getSizes()[0])); - Value resultShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - SmallVector{resultLen}); - Value output = rewriter.create( - loc, resultType, resultShapeList, cstNone, cstNone, cstNone, cstNone); + // Create an empty tensor for result + Value result = rewriter.create( + loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); auto intTy = rewriter.getType(); auto rowSliceTy = @@ -10761,8 +10751,6 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { rewriter.getType(SmallVector{1, 2}, dType); auto extractTy = rewriter.getType( SmallVector{1}, rewriter.getIntegerType(64, true)); - Value cnt = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); Value float0 = rewriter.create( loc, rewriter.getFloatAttr(dType, 0.0)); auto scalarFloatType = rewriter.getType( @@ -10776,17 +10764,18 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { // 4. Loop through the rest boxes in sorted indices // 5. Suppress the box if the corresponding IoU is larger than threshold auto loop1 = rewriter.create( - loc, TypeRange({maskTy, resultType, intTy}), len, cstTrue, - ValueRange({mask, output, cnt})); + loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue, + ValueRange({mask, result, cst0})); { PatternRewriter::InsertionGuard guard(rewriter); - Block *rowLoopBody = rewriter.createBlock( + Block *loopBody1 = rewriter.createBlock( &loop1.getRegion(), loop1.getRegion().begin(), - TypeRange({intTy, maskTy, resultType, intTy}), {loc, loc, loc, loc}); - Value i = rowLoopBody->getArgument(0); - Value mask1 = rowLoopBody->getArgument(1); - Value curOutput = rowLoopBody->getArgument(2); - Value curCnt = rowLoopBody->getArgument(3); + TypeRange({intTy, intTensorType, intTensorType, intTy}), + {loc, loc, loc, loc}); + Value i = loopBody1->getArgument(0); + Value mask1 = loopBody1->getArgument(1); + Value curResult = loopBody1->getArgument(2); + Value curCnt = loopBody1->getArgument(3); // Extract the mask to check if the base box is suppressed Value extract = rewriter.create( @@ -10795,20 +10784,20 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value iskept = rewriter.create( loc, rewriter.getType(), scalar); auto ifFilterOthers = rewriter.create( - loc, TypeRange({maskTy, resultType, intTy}), iskept); + loc, TypeRange({intTensorType, intTensorType, intTy}), iskept); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifFilterOthers.getThenRegion(), ifFilterOthers.getThenRegion().begin()); - // Fill the selected indices into output + // Scatter the selected indices into result Value extractIdx1 = rewriter.create( loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, /*index=*/i); - Value nextCnt = rewriter.create(loc, curCnt, cst1); - Value updatedOutput = rewriter.create( - loc, resultType, curOutput, extractIdx1, /*dim=*/cst0, - /*start=*/curCnt, /*end=*/nextCnt, /*step=*/cst1); + Value next = rewriter.create(loc, curCnt, cst1); + Value updatedResult = rewriter.create( + loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0, + /*start=*/curCnt, /*end=*/next, /*step=*/cst1); // Get the coordinates of base box Value idx1 = @@ -10818,6 +10807,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { loc, rowSliceTy, boxes, /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1); + // Calculate IoUs: intersectionArea / unionArea // Intersection area = intersectionWidth * intersectionHeight Value point1 = rewriter.create( loc, pointTy, slice1, @@ -10849,22 +10839,22 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { loc, areaTy, intersectionArea, unionArea); // Loop through the rest of boxes in sorted indices - auto loop2 = rewriter.create(loc, maskTy, len, - cstTrue, mask1); + auto loop2 = rewriter.create(loc, intTensorType, + len, cstTrue, mask1); { PatternRewriter::InsertionGuard guard(rewriter); - Block *colLoopBody = rewriter.createBlock( + Block *loopBody2 = rewriter.createBlock( &loop2.getRegion(), loop2.getRegion().begin(), - TypeRange({intTy, maskTy}), {loc, loc}); - Value j = colLoopBody->getArgument(0); - Value mask2 = colLoopBody->getArgument(1); + TypeRange({intTy, intTensorType}), {loc, loc}); + Value j = loopBody2->getArgument(0); + Value mask2 = loopBody2->getArgument(1); // Check if current index is out of range j = rewriter.create(loc, j, i); j = rewriter.create(loc, j, cst1); Value isInRange = rewriter.create(loc, j, len); auto ifCalculateIou = rewriter.create( - loc, TypeRange({maskTy}), isInRange); + loc, TypeRange({intTensorType}), isInRange); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifCalculateIou.getThenRegion(), @@ -10887,7 +10877,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { loc, curIoU, iouThreshold); auto ifUnmask = rewriter.create( - loc, TypeRange({maskTy}), isSuppressed); + loc, TypeRange({intTensorType}), isSuppressed); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifUnmask.getThenRegion(), @@ -10896,7 +10886,7 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { // Update the mask if suppress Value jEnd = rewriter.create(loc, j, cst1); Value updatedMask = rewriter.create( - loc, maskTy, mask2, falseMask, /*dim=*/cst0, + loc, intTensorType, mask2, falseMask, /*dim=*/cst0, /*start=*/j, /*end=*/jEnd, /*step=*/cst1); rewriter.create(loc, updatedMask); } @@ -10921,21 +10911,23 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { } rewriter.create( - loc, ValueRange({loop2.getResult(0), updatedOutput, nextCnt})); + loc, ValueRange({loop2.getResult(0), updatedResult, next})); } { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifFilterOthers.getElseRegion(), ifFilterOthers.getElseRegion().begin()); rewriter.create( - loc, ValueRange({mask1, curOutput, curCnt})); + loc, ValueRange({mask1, curResult, curCnt})); } rewriter.create(loc, cstTrue, ifFilterOthers.getResults()); } - rewriter.replaceOp(op, loop1.getResult(1)); + rewriter.replaceOpWithNewOp( + op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0, + /*end=*/loop1.getResult(2), /*step=*/cst1); return success(); } }; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 8ea63059efd9..33dc51f14cce 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -2054,8 +2054,8 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4] // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32> // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[1],f32> - // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float @@ -2106,8 +2106,8 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> - // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float