From 859f5d280f68bf6b5e5c0ac61159630ec9824217 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Fri, 12 Apr 2024 15:18:22 -0700 Subject: [PATCH] Generalize getting index for onnx compress op (#3150) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 64 ++++++++----------- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 17 ++--- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index f8829ec06b25..e35880909bc1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -721,35 +721,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - // get indexs from the condition tensor - auto dtype = dyn_cast(conditionTensor.getType()) - .getDtype(); - auto constOp = dyn_cast( - conditionTensor.getDefiningOp()); - auto elementsAttr = - dyn_cast(constOp.getValueAttr()); - SmallVector apValues; - int64_t index = 0; - for (auto intAttr : elementsAttr.getValues()) { - int64_t i = dyn_cast(intAttr).getSInt(); - if (i) - apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), index)); - index++; - } - SmallVector indexShape = {static_cast(apValues.size())}; - auto indexType = Torch::ValueTensorType::get(binder.op->getContext(), - indexShape, dtype); - auto attr = DenseElementsAttr::get( - cast(RankedTensorType::get(indexShape, dtype)), - apValues); - Value indexVal = - rewriter.replaceOpWithNewOp( - constOp, indexType, attr); - auto shapeSizes = dyn_cast(operand.getType()).getSizes(); + auto resultSizes = resultType.getSizes(); + + // flatten input tensor if using default axis if (axis == INT64_MAX) { - // flatten input tensor if using default axis + SmallVector nonzeroShape = {resultSizes[0]}; + auto dtype = + dyn_cast(conditionTensor.getType()) + .getDtype(); + auto nonzeroType = + rewriter.getType(nonzeroShape, dtype); + Value indexVal = rewriter.create( + binder.getLoc(), nonzeroType, conditionTensor); Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cstNegOne = rewriter.create( @@ -759,22 +744,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( numElements *= i; } SmallVector flattenShape = {numElements}; - auto flattenType = Torch::ValueTensorType::get( - binder.op->getContext(), flattenShape, resultType.getDtype()); + auto flattenType = rewriter.getType( + flattenShape, resultType.getDtype()); Value flattenTensor = rewriter.create( binder.getLoc(), flattenType, operand, cstZero, cstNegOne); rewriter.replaceOpWithNewOp( binder.op, resultType, flattenTensor, cstZero, indexVal); return success(); - } else { - if (axis < 0) - // Negative axis value means counting dimensions from the back - axis += shapeSizes.size(); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, dim, indexVal); } + + // Negative axis value means counting dimensions from the back + if (axis < 0) + axis += shapeSizes.size(); + SmallVector nonzeroShape = {resultSizes[axis]}; + auto dtype = dyn_cast(conditionTensor.getType()) + .getDtype(); + auto nonzeroType = + rewriter.getType(nonzeroShape, dtype); + Value indexVal = rewriter.create( + binder.getLoc(), nonzeroType, conditionTensor); + Value dimVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, dimVal, indexVal); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 2aaf2a6977a2..4a70656b9c85 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1709,12 +1709,11 @@ func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3, // ----- // CHECK-LABEL: func.func @test_compress -func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> +func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,2], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %arg1 : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64> // CHECK: %[[DIM:.*]] = torch.constant.int 2 - // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32> - %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> - %0 = torch.operator "onnx.Compress"(%arg0, %cst) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32> + // CHECK: torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,3,2],f32> + %0 = torch.operator "onnx.Compress"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,2],f32> return %0 : !torch.vtensor<[2,3,2],f32> } @@ -1722,11 +1721,12 @@ func.func @test_compress(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[ // CHECK-LABEL: func.func @test_compress_default_axis func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 3, 5]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 0, 1, 0, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64> + // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[6],si64> -> !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[END_DIM:.*]] = torch.constant.int -1 // CHECK: %[[ATEN_FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0]], %[[END_DIM]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32> - // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.index_select %[[ATEN_FLATTEN]], %[[INT0]], %[[INDEX]] : !torch.vtensor<[6],f32>, !torch.int, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],f32> %cst = torch.vtensor.literal(dense<[0,1,0,1,0,1]> : tensor<6xsi64>) : !torch.vtensor<[6], si64> %0 = torch.operator "onnx.Compress"(%arg0, %cst) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> @@ -1736,7 +1736,8 @@ func.func @test_compress_default_axis(%arg0: !torch.vtensor<[2,3],f32>) -> !torc // CHECK-LABEL: func.func @test_compress_neg_axis func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,2,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INDEX:.*]] = torch.vtensor.literal(dense<[1, 2]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<[0, 1, 1]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INDEX:.*]] = torch.aten.nonzero %[[CST]] : !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64> // CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[ATEN_INDEX_SELECT:.*]] = torch.aten.index_select %arg0, %[[DIM]], %[[INDEX]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,2,4],f32> %cst = torch.vtensor.literal(dense<[0,1,1]> : tensor<3xsi64>) : !torch.vtensor<[3], si64>