diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 3db33aee1f1c..ad0b2b8cd500 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3701,63 +3701,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return rewriter.notifyMatchFailure(
               binder.op, "expected center_point_box attribute to be 0 or 1");
 
-        // TODO: Support multiple batches and classes
-        // Squeeze the boxes and scores tensor.
-        // In Onnx, the shape of boxes is [BxNx4] while the
-        // torchvision expects it to be of shape [Nx4]. Similarly, for
-        // the scores tensor shape in Onnx is [BxCxN] while the
-        // torchvision expects it to be of shape [N].
+        Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+        Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
+        Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(2.0));
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(false));
+
+        // In Onnx, the shape of boxes is [BxNx4] and that of scores is [BxCxN].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
-        if (failed(squeezedBoxes))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze boxes tensor");
-        FailureOr<Value> squeezedScores =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
-                                              squeezedScores.value());
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        boxes = squeezedBoxes.value();
-        scores = squeezedScores.value();
+
+        auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
+        auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
+        auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
+        auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
+
         if (centerPointBox == 1) {
           // When center_point_box is 1, the box data is supplied as
           // [[x_center, y_center, width, height], ...]. Slice it to
           // [[x_center, y_center], ...] and [[width, height], ...],
          // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
           // to [[x1, y1, x2, y2], ...]
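+          // For example, a box [5.0, 5.0, 4.0, 2.0] (center (5, 5), width 4,
+          // height 2) becomes [3.0, 4.0, 7.0, 6.0].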
-          auto boxesTensorType =
-              dyn_cast<Torch::ValueTensorType>(boxes.getType());
-          Value const0 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(0));
-          Value const1 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(1));
-          Value const2 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(2));
-          Value const4 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(4));
-          Value const2F = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(2.0));
 
           // extract scaled ranges for regions of interest
-          auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
+          auto sliceShape =
+              SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize, 2};
           auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
               sliceShape, boxesTensorType.getDtype());
+
+          // Boxes have shape [BxNx4]
           Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const0, const2, const1);
+              loc, sliceTensorType, boxes, cst2, cst0, cst2, cst1);
           Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const2, const4, const1);
+              loc, sliceTensorType, boxes, cst2, cst2, cst4, cst1);
           Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
-              loc, sizes.getType(), sizes, const2F);
+              loc, sizes.getType(), sizes, cst2F);
           Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
           Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
 
           Type listElemType = boxesTensorType.getWithSizesAndDtype(
               /*optionalSizes=*/std::nullopt,
@@ -3766,7 +3756,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
               loc, listType, SmallVector<Value>{x1y1s, x2y2s});
           boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
-                                                    tensorList, const1);
+                                                    tensorList, cst2);
         }
 
         // TODO: Support score_threshold input
@@ -3792,10 +3782,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
 
         // Get max_output_boxes_per_class and iou_threshold
-        Value cst0 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        Value cst1 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
         Value maxOutputBoxesPerClass = cst0;
         Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(0.0));
@@ -3810,87 +3796,207 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, rewriter.getType<Torch::IntType>(), operands[2]);
         }
 
+        // Since boxes have shape [BxNx4] in Onnx while torchvision expects
+        // [Nx4], loop over the batch dimension. Similarly, since scores have
+        // shape [BxCxN] in Onnx while torchvision expects [N], also loop over
+        // the class dimension.
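+        // A rough sketch of the algorithm (illustrative pseudocode, not the
+        // exact op sequence):
+        //   for b in range(numBatches):
+        //     for c in range(numClasses):
+        //       keep = torchvision_nms(boxes[b], scores[b][c], iou_threshold)
+        //       for i in keep[:max_output_boxes_per_class]:
+        //         finalResult[idx++] = [b, c, i]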
+        auto numBatches =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
+        auto numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
+
+        // Create an empty tensor of shape
+        // (B * C * max_output_boxes_per_class, 3) to store the final result.
+        // It is sliced down to the number of valid entries at the end.
+
+        Value numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numBatches, numClasses);
+        numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
+
+        auto intTy = rewriter.getType<Torch::IntType>();
+        auto intListTy = rewriter.getType<Torch::ListType>(intTy);
+
+        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, SmallVector<Value>{numResults, cst3});
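+        // dtype 4 corresponds to PyTorch's ScalarType::Long (int64), matching
+        // the si64 element type Onnx expects for the selected indices.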
+        Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
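+        // `nmsTy` is the torchvision.nms result type: a dynamically sized
+        // ([-1]) 1-D tensor of selected box indices.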
         auto nmsTy = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{-1},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            loc, nmsTy, boxes, scores, iouThreshold);
 
-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
-            loc, numOutputBoxes, maxOutputBoxesPerClass);
+        auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
+            SmallVector<int64_t>{}, nmsTy.getDtype());
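+        // Rank-0 (scalar) si64 tensor type, used to wrap loop indices and
+        // counts so they can feed tensor ops such as aten.minimum.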
 
-        auto nmsResultTy = Torch::ValueTensorType::get(
-            binder.op->getContext(),
-            SmallVector<int64_t>{resultType.getSizes()[0]},
-            rewriter.getIntegerType(64, /*signed=*/true));
-        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
-            loc, TypeRange({nmsResultTy}), boxesCond);
+        auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
+            loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
+            ValueRange({finalResult, /*Index to finalResult*/ cst0,
+                        /*Num values in result*/ cst0}));
         {
+          // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getThenRegion(),
-                               ifSlice.getThenRegion().begin());
+          Block *batchLoopBody = rewriter.createBlock(
+              &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
+              TypeRange({intTy, resultType, intTy, intTy}),
+              {loc, loc, loc, loc});
+
+          auto batchIV = batchLoopBody->getArgument(0);
+          auto currRes = batchLoopBody->getArgument(1);
+          auto finalResIdx = batchLoopBody->getArgument(2);
+          auto numResultValues = batchLoopBody->getArgument(3);
+
+          auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, boxSlicedType, boxes, cst0, batchIV);
+          auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc, emptyTensorTy, batchIV);
+
+          auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, scoreSlicedType, scores, cst0, batchIV);
+          auto scoreSelectType =
+              cast<Torch::ValueTensorType>(scoreSelect.getType());
+          auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+              scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());
+
+          auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
+              loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
+              ValueRange({currRes, finalResIdx, numResultValues}));
 
-          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
-              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        {
-          PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getElseRegion(),
-                               ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
-              loc, nmsResultTy, result);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        result = ifSlice.getResult(0);
-
-        // The result generated by torchvision.nms op is of shape [n], while the
-        // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
-        // and make it of shape [n, 1] and then concatenate it with a zero
-        // tensor of shape [n, 2] to make it of shape [n, 3].
-        FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
-        if (failed(unsqueezedResult))
-          return rewriter.notifyMatchFailure(
-              binder.op, "failed to  unsqueeze result tensor");
-        result = unsqueezedResult.value();
-
-        numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        SmallVector<Value> zerosShapeValues{numOutputBoxes};
-        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2)));
-        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            loc,
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            zerosShapeValues);
-        std::optional<ArrayRef<int64_t>> resultShape =
-            cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
-        if (!resultShape.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result tensor to have shape");
-        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
-        auto zerosTy = Torch::ValueTensorType::get(
-            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
-
-        Type listElemType =
-            cast<Torch::BaseTensorType>(resultType)
-                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
-                                      /*optionalDtype=*/nullptr);
-        Type listType = Torch::ListType::get(listElemType);
-        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            loc, listType, SmallVector<Value>{zeros, result});
-        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, cst1);
+          {
+            // Class loop body
+            PatternRewriter::InsertionGuard guard(rewriter);
+            Block *classLoopBody = rewriter.createBlock(
+                &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
+                TypeRange({intTy, resultType, intTy, intTy}),
+                {loc, loc, loc, loc});
+
+            auto classIV = classLoopBody->getArgument(0);
+            auto currRes = classLoopBody->getArgument(1);
+            auto finalResIdx = classLoopBody->getArgument(2);
+            Value numResultValues = classLoopBody->getArgument(3);
+
+            auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreValueType, scoreSelect, cst0, classIV);
+            auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, classIV);
+
+            // TorchVision NMS
+            Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+                loc, nmsTy, boxValue, scoreValue, iouThreshold);
+
+            // Clamp the number of boxes selected by nms to
+            // max_output_boxes_per_class
+            Value numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+            numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, numOutputBoxes);
+            Value maxBoxesPerClass =
+                rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                    loc, emptyTensorTy, maxOutputBoxesPerClass);
+            auto minVal = rewriter.create<Torch::AtenMinimumOp>(
+                loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
+            numOutputBoxes =
+                rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
+
+            // Loop through the nms result.
+            // torchvision's nms returns a tensor of shape [num_selected],
+            // while Onnx expects [num_selected, 3] where each row holds
+            // [batch_index, class_index, box_index].
+            // Insert the triplet [batch_index, class_index, box_index] into
+            // `finalResult` element by element for each box.
+
+            // TODO: This could be simplified by concatenating the nms result
+            // with tensors filled with the batch and class indices instead of
+            // using the loop below. Currently that approach fails during
+            // lowering due to dynamic dims.
+
+            auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
+                loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
+                ValueRange({currRes, finalResIdx}));
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              Block *loopBody = rewriter.createBlock(
+                  &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
+                  TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+              auto iter = loopBody->getArgument(0);
+              auto currRes = loopBody->getArgument(1);
+              auto idxCst = loopBody->getArgument(2);
+
+              auto outputTensorSliceType =
+                  rewriter.getType<Torch::ValueTensorType>(
+                      SmallVector<int64_t>{3}, nmsTy.getDtype());
+
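+              // Tensors here have value semantics, so each element write is
+              // expressed as select -> copy -> select_scatter: copy the scalar
+              // into a rank-0 view, scatter it back into the selected row,
+              // then scatter the updated row back into the result tensor.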
+              // Update batch dimension
+              auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+              auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, batchDim3D, cst0, cst0);
+              auto bCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, batchSelect.getType(), batchSelect, batchValue,
+                  cstFalse);
+              batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+              auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
+              auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, currRes, scatterBatch, cst0, idxCst);
+
+              // Update class dimension
+              auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, classDim3D, cst0, cst1);
+              auto cCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, classSelect.getType(), classSelect, classValue,
+                  cstFalse);
+              classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
+              auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, batchResult, scatterClass, cst0, idxCst);
+
+              // Update nms result dimension
+              auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, resDim3D, cst0, cst2);
+              auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, result, cst0, iter);
+              auto rCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, resSelect.getType(), resSelect, nmsResultValue,
+                  cstFalse);
+              resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
+              Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, classRes, scatterRes, cst0, idxCst);
+
+              // Increment the result index
+              Value next =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);
+              rewriter.create<Torch::PrimLoopConditionOp>(
+                  loc, cstTrue, ValueRange({nmsResult, next}));
+            }
+            // Update the running count of result values
+            numResultValues = rewriter.create<Torch::AtenAddIntOp>(
+                loc, numResultValues, numOutputBoxes);
+            rewriter.create<Torch::PrimLoopConditionOp>(
+                loc, cstTrue,
+                ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1),
+                            numResultValues}));
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              loc, cstTrue,
+              ValueRange({nmsClassLoop.getResult(0), nmsClassLoop.getResult(1),
+                          nmsClassLoop.getResult(2)}));
+        }
+        // Slice the result down to the required number of elements.
+        rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
+            binder.op, resultType, nmsBatchLoop.getResult(0), cst0, cst0,
+            nmsBatchLoop.getResult(2), cst1);
         return success();
       });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index b2c718bceace..e4ec736c2a8e 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -2034,53 +2034,25 @@ func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtens
 // CHECK-SAME:                                                      %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:                                                      %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4],f32>, %arg1: !torch.vtensor<[1,1,10],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK:           %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.vtensor<[10,4],f32>
-  // CHECK:           %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
-  // CHECK:           %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-  // CHECK:           %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32>
-  // CHECK:           %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK:           %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK:           %[[VAL_24:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_25:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
-  // CHECK:           %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK:           %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK:           %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
-  // CHECK:             %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
-  // CHECK:           } else {
-  // CHECK:             %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
-  // CHECK:           }
-  // CHECK:           %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK:           %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_35:.*]] = torch.constant.int 2
-  // CHECK:           %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK:           %[[VAL_37:.*]] = torch.constant.none
-  // CHECK:           %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK:           %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK:           %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK:           return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
+// CHECK:           %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK:           %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK:             %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK:               %[[NMS:.*]] = torch.torchvision.nms
+// CHECK:               %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK:               %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK:               %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK:                 %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK:                 %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK:                 %[[COPY:.*]] = torch.aten.copy
+// CHECK:                 %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6:         torch.aten.select_scatter
+// CHECK:                 %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK:                 torch.prim.Loop.condition
+// CHECK:                 %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK:               torch.prim.Loop.condition
+// CHECK:             torch.prim.Loop.condition
+// CHECK:           %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK:           return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
@@ -2094,53 +2066,25 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
 // CHECK-SAME:                                                 %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:                                                 %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""}
 func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK:           %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK:           %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
-  // CHECK:           %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
-  // CHECK:           %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
-  // CHECK:           %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK:           %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK:           %[[VAL_24:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_25:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
-  // CHECK:           %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK:           %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK:           %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
-  // CHECK:             %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
-  // CHECK:           } else {
-  // CHECK:             %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
-  // CHECK:           }
-  // CHECK:           %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK:           %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_35:.*]] = torch.constant.int 2
-  // CHECK:           %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK:           %[[VAL_37:.*]] = torch.constant.none
-  // CHECK:           %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK:           %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK:           %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK:           return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
+// CHECK:           %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK:           %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK:             %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK:               %[[NMS:.*]] = torch.torchvision.nms
+// CHECK:               %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK:               %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK:               %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK:                 %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK:                 %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK:                 %[[COPY:.*]] = torch.aten.copy
+// CHECK:                 %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6:         torch.aten.select_scatter
+// CHECK:                 %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK:                 torch.prim.Loop.condition
+// CHECK:                 %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK:               torch.prim.Loop.condition
+// CHECK:             torch.prim.Loop.condition
+// CHECK:           %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK:           return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
@@ -2152,68 +2096,82 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
 // CHECK-SAME:                                                       %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:                                                       %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK:           %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK:           %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
-  // CHECK:           %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK:           %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
-  // CHECK:           %[[VAL_20:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_21:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_22:.*]] = torch.constant.int 2
-  // CHECK:           %[[VAL_23:.*]] = torch.constant.int 4
-  // CHECK:           %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
-  // CHECK:           %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK:           %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK:           %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32>
-  // CHECK:           %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK:           %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK:           %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
-  // CHECK:           %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK:           %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
-  // CHECK:           %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK:           %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK:           torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK:           %[[VAL_36:.*]] = torch.constant.int 0
-  // CHECK:           %[[VAL_37:.*]] = torch.constant.int 1
-  // CHECK:           %[[VAL_38:.*]] = torch.constant.float 0.000000e+00
-  // CHECK:           %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK:           %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK:           %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK:           %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK:           %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) {
-  // CHECK:             %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64>
-  // CHECK:           } else {
-  // CHECK:             %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:             torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64>
-  // CHECK:           }
-  // CHECK:           %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK:           %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK:           %[[VAL_49:.*]] = torch.constant.int 2
-  // CHECK:           %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK:           %[[VAL_51:.*]] = torch.constant.none
-  // CHECK:           %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK:           %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK:           %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK:           return %[[VAL_54]] : !torch.vtensor<[1,3],si64>
+// CHECK:           %[[VAL_5:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_6:.*]] = torch.constant.int 1
+// CHECK:           %[[VAL_7:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_8:.*]] = torch.constant.int 3
+// CHECK:           %[[VAL_9:.*]] = torch.constant.int 4
+// CHECK:           %[[VAL_10:.*]] = torch.constant.float 2.000000e+00
+// CHECK:           %[[VAL_11:.*]] = torch.constant.none
+// CHECK:           %[[VAL_12:.*]] = torch.constant.bool true
+// CHECK:           %[[VAL_13:.*]] = torch.constant.bool false
+// CHECK:           %[[VAL_14:.*]] = torch.aten.slice.Tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_5]], %[[VAL_7]], %[[VAL_6]] : !torch.vtensor<[1,1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK:           %[[VAL_15:.*]] = torch.aten.slice.Tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_7]], %[[VAL_9]], %[[VAL_6]] : !torch.vtensor<[1,1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK:           %[[VAL_16:.*]] = torch.aten.div.Scalar %[[VAL_15]], %[[VAL_10]] : !torch.vtensor<[?,?,2],f32>, !torch.float -> !torch.vtensor<[?,?,2],f32>
+// CHECK:           %[[VAL_17:.*]] = torch.aten.sub.Tensor %[[VAL_14]], %[[VAL_16]], %[[VAL_6]] : !torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK:           %[[VAL_18:.*]] = torch.aten.add.Tensor %[[VAL_14]], %[[VAL_16]], %[[VAL_6]] : !torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK:           %[[VAL_19:.*]] = torch.prim.ListConstruct %[[VAL_17]], %[[VAL_18]] : (!torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>) -> !torch.list<vtensor>
+// CHECK:           %[[VAL_20:.*]] = torch.aten.cat %[[VAL_19]], %[[VAL_7]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,1,4],f32>
+// CHECK:           %[[VAL_21:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK:           %[[VAL_22:.*]] = torch.aten.min %[[VAL_1]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[],f32>
+// CHECK:           %[[VAL_23:.*]] = torch.aten.item %[[VAL_22]] : !torch.vtensor<[],f32> -> !torch.float
+// CHECK:           %[[VAL_24:.*]] = torch.aten.ge.float %[[VAL_23]], %[[VAL_21]] : !torch.float, !torch.float -> !torch.bool
+// CHECK:           torch.runtime.assert %[[VAL_24]], "unimplemented: score_threshold should be <= min(scores)"
+// CHECK:           %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK:           %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK:             %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK:               %[[NMS:.*]] = torch.torchvision.nms
+// CHECK:               %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK:               %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK:               %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK:                 %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK:                 %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK:                 %[[COPY:.*]] = torch.aten.copy
+// CHECK:                 %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6:         torch.aten.select_scatter
+// CHECK:                 %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK:                 torch.prim.Loop.condition
+// CHECK:                 %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK:               torch.prim.Loop.condition
+// CHECK:             torch.prim.Loop.condition
+// CHECK:           %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK:           return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_nonmaxsuppression_multiple_batch_class(
+// CHECK-SAME:                                                           %[[VAL_0:.*]]: !torch.vtensor<[3,8,4],f32>,
+// CHECK-SAME:                                                           %[[VAL_1:.*]]: !torch.vtensor<[3,5,8],f32>,
+// CHECK-SAME:                                                           %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
+// CHECK-SAME:                                                           %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
+// CHECK-SAME:                                                           %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+func.func @test_nonmaxsuppression_multiple_batch_class(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+// CHECK:           %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK:           %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK:             %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK:               %[[NMS:.*]] = torch.torchvision.nms
+// CHECK:               %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK:               %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK:               %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK:                 %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK:                 %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK:                 %[[COPY:.*]] = torch.aten.copy
+// CHECK:                 %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6:         torch.aten.select_scatter
+// CHECK:                 %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK:                 torch.prim.Loop.condition
+// CHECK:                 %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK:               torch.prim.Loop.condition
+// CHECK:             torch.prim.Loop.condition
+// CHECK:           %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK:           return %[[SLICE_RES]] : !torch.vtensor<[?,3],si64>
+  %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 0 : si64} : (!torch.vtensor<[3,8,4],f32>, !torch.vtensor<[3,5,8],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64>
+  return %0 : !torch.vtensor<[?,3],si64>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @test_mwm