From a0232e9ebd403fefe57325562e5d129e9e60643e Mon Sep 17 00:00:00 2001
From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com>
Date: Tue, 16 Apr 2024 12:24:46 +0530
Subject: [PATCH] [MLIR][TORCH] Add OnnxToTorch lowering for ReduceL1 Op (#3146)

Adds OnnxToTorch lowering for the ReduceL1 op.
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 268 ++++++++++--------
 projects/pt1/e2e_testing/xfail_sets.py        |   9 +-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   |  53 ++++
 3 files changed, 213 insertions(+), 117 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index d5b4971f9e15..b4bd102f152f 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -39,6 +39,127 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
   return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                             rewriter.getType<Torch::IntType>(),
                                             ofItem);
 }
+
+// If the reduced sum is not the first operation performed on the data (e.g.
+// ReduceL1 takes the absolute value first), the caller provides the original
+// operand through storeResult. storeResult is used for noop_with_empty_axes
+// handling, and is overwritten with the new result when that result is to be
+// passed on to another operation.
+LogicalResult reducedSumImpl(OpBinder binder,
+                             ConversionPatternRewriter &rewriter, Value data,
+                             Torch::ValueTensorType resultType,
+                             Value &storeResult, int64_t keepDims,
+                             int64_t noop_with_empty_axes,
+                             bool isIntermediateOp) {
+
+  SmallVector<Value> axesList;
+  Value axesVal;
+  if (!binder.tensorOperandAtIndex(axesVal, 1)) {
+    auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
+    if (!inputType.hasSizes() || !resultType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          binder.op, "unimplemented: expected input and result to have shapes");
+    }
+
+    if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
+      SmallVector<int64_t> inputShape{inputType.getSizes()};
+      SmallVector<int64_t> resultShape{resultType.getSizes()};
+      // If the shapes are equal, none of the dims is reduced.
+      if (llvm::equal(inputShape, resultShape)) {
+        // Simply replace the op with its input and return.
+        rewriter.replaceOp(binder.op, data);
+        return success();
+      }
+      if (areAllElementsDistinct(inputShape)) {
+        // The input shape elements are required to be distinct because of
+        // cases like:
+        //   Input: [3, 2, 2] -> Output: [3, 2]
+        // where the input and output shapes alone cannot tell whether dim 1
+        // or dim 2 was reduced.
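+        // Illustrative trace (not part of the original patch): for input
+        // [4, 3, 2] and result [4, 2], the walk below matches 4, mismatches
+        // at 3 and records dim 1, then matches 2, so reduceDims = {1}.
+        // With keepdims, result [4, 1, 2] takes the `== 1` branch, so the
+        // retained unit dim is skipped rather than misread as a match.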
+        SmallVector<int64_t> reduceDims;
+        unsigned resultShapeCounter = 0;
+        for (unsigned i = 0; i < inputShape.size(); i++) {
+          if (resultShapeCounter < resultShape.size() &&
+              inputShape[i] == resultShape[resultShapeCounter]) {
+            resultShapeCounter++;
+          } else {
+            reduceDims.push_back(i);
+            if (resultShapeCounter < resultShape.size() &&
+                resultShape[resultShapeCounter] == 1)
+              resultShapeCounter++;
+          }
+        }
+        for (auto i : reduceDims) {
+          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
+      }
+    }
+    if (axesList.empty()) {
+      Torch::BaseTensorType axesType =
+          axesVal.getType().cast<Torch::BaseTensorType>();
+      auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
+      auto axesShape = axesTy.getSizes();
+      if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
+        return failure();
+
+      Value zero = rewriter.create<Torch::ConstantIntOp>(
+          binder.getLoc(), rewriter.getType<Torch::IntType>(),
+          rewriter.getI64IntegerAttr(0));
+      SmallVector<int64_t> selectSizes{1};
+      auto selType = rewriter.getType<Torch::ValueTensorType>(
+          selectSizes, axesType.getOptionalDtype());
+      int64_t numAxes = axesShape[0];
+      for (int64_t i = 0; i < numAxes; ++i) {
+        Value iv = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getI64IntegerAttr(i));
+        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+            binder.getLoc(), selType, axesVal, zero, iv);
+        Value dim = rewriter.create<Torch::AtenItemOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+        axesList.push_back(dim);
+      }
+    }
+  }
+
+  SmallVector<int64_t> axesInts;
+  if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
+    for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
+      Value iv = rewriter.create<Torch::ConstantIntOp>(
+          binder.getLoc(), rewriter.getType<Torch::IntType>(),
+          rewriter.getI64IntegerAttr(axesInts[i]));
+      axesList.push_back(iv);
+    }
+  }
+
+  // Handle the noop case by replacing the op with the original operand
+  // (storeResult), which for ops like ReduceL1 does not include the
+  // absolute value.
+  if (axesList.empty() && noop_with_empty_axes) {
+    rewriter.replaceOp(binder.op, storeResult);
+    return success();
+  }
+
+  Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+      binder.getLoc(),
+      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+      axesList);
+  Value keepDimBool =
+      rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
+  Value dType = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+  // When the reduced sum is an intermediate op to be passed into another
+  // operation, we must not replace the original op; instead, create a new
+  // op and store its result in storeResult.
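+  // Illustrative note (not part of the original patch): both ReduceL1 and
+  // ReduceSum below pass isIntermediateOp=false; a later lowering that
+  // post-processes the sum (e.g. with a sqrt or log) would pass true and
+  // read the summed value back out of storeResult.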
+  if (!isIntermediateOp) {
+    rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
+        binder.op, resultType, data, dimValueList, keepDimBool,
+        /*dtype=*/dType);
+  } else {
+    storeResult = rewriter.create<Torch::AtenSumDimIntListOp>(
+        binder.getLoc(), resultType, data, dimValueList, keepDimBool,
+        /*dtype=*/dType);
+  }
+  return success();
+}
 } // namespace
 
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
@@ -758,124 +879,41 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
-  patterns.onOp(
-      "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value data;
-        int64_t keepDims, noop_with_empty_axes;
-        if (binder.tensorOperandAtIndex(data, 0) ||
-            binder.tensorResultType(resultType) ||
-            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
-            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
-                                  0))
-          return failure();
-
-        SmallVector<Value> axesList;
-
-        Value axesVal;
-        if (!binder.tensorOperandAtIndex(axesVal, 1)) {
-          auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
-          if (!inputType.hasSizes() || !resultType.hasSizes()) {
-            return rewriter.notifyMatchFailure(
-                binder.op,
-                "unimplemented: expected input and result to have shapes");
-          }
-
-          // If the input shape and result shape is statically known then the
-          // list of dims to be squeezed can be derived from those shapes. As a
-          // result, we don't have to wait for the dim values to be known at
-          // runtime which is also expected by the downstream pipeline.
-          if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
-            SmallVector<int64_t> inputShape{inputType.getSizes()};
-            SmallVector<int64_t> resultShape{resultType.getSizes()};
-            if (llvm::equal(inputShape, resultShape)) {
-              // Case: none of the dimension is reduced.
-              rewriter.replaceOp(binder.op, data);
-              return success();
-            }
-            if (areAllElementsDistinct(inputShape)) {
-              // The check for the input shape elements to be distinct is added
-              // for the cases like:
-              // Input: [3, 2, 2] -> Output: [3, 2]
-              // For the above case, from the input and output shape it can't be
-              // inferred whether the dim:1 is reduced or dim:2. To avoid these
-              // type of cases, the check has been placed.
-              SmallVector<int64_t> reduceDims;
-              unsigned resultShapeCounter = 0;
-              for (unsigned i = 0; i < inputShape.size(); i++) {
-                if (resultShapeCounter < resultShape.size() &&
-                    inputShape[i] == resultShape[resultShapeCounter]) {
-                  resultShapeCounter++;
-                } else {
-                  reduceDims.push_back(i);
-                  if (resultShapeCounter < resultShape.size() &&
-                      resultShape[resultShapeCounter] == 1)
-                    resultShapeCounter++;
-                }
-              }
-              for (auto i : reduceDims) {
-                axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
-                    binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-              }
-            }
-          }
-
-          if (axesList.empty()) {
-            Torch::BaseTensorType axesType =
-                axesVal.getType().cast<Torch::BaseTensorType>();
-            auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
-            auto axesShape = axesTy.getSizes();
-            if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
-              return failure();
-
-            Value zero = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getI64IntegerAttr(0));
-            SmallVector<int64_t> selectSizes{1};
-            auto selType = rewriter.getType<Torch::ValueTensorType>(
-                selectSizes, axesType.getOptionalDtype());
-            int64_t numAxes = axesShape[0];
-            for (int64_t i = 0; i < numAxes; ++i) {
-              Value iv = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                  rewriter.getI64IntegerAttr(i));
-              Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-                  binder.getLoc(), selType, axesVal, zero, iv);
-              Value dim = rewriter.create<Torch::AtenItemOp>(
-                  binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
-              axesList.push_back(dim);
-            }
-          }
-        }
+  patterns.onOp("ReduceL1", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  int64_t keepDims, noop_with_empty_axes;
+                  Value operand;
+                  if (binder.tensorOperandAtIndex(operand, 0) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+                      binder.s64IntegerAttr(noop_with_empty_axes,
+                                            "noop_with_empty_axes", 0))
+                    return failure();
-        SmallVector<int64_t> axesInts;
-        if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
-          for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
-            Value iv = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getI64IntegerAttr(axesInts[i]));
-            axesList.push_back(iv);
-          }
-        }
+                  Value data = rewriter.create<Torch::AtenAbsOp>(
+                      binder.getLoc(), operand.getType(), operand);
-        // deal with case when axes is empty
-        if (axesList.empty() && noop_with_empty_axes) {
-          rewriter.replaceOp(binder.op, data);
-          return success();
-        }
+                  return reducedSumImpl(binder, rewriter, data, resultType,
+                                        /*storeResult=*/operand, keepDims,
+                                        noop_with_empty_axes, false);
+                });
+  patterns.onOp("ReduceSum", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value data;
+                  int64_t keepDims, noop_with_empty_axes;
+                  if (binder.tensorOperandAtIndex(data, 0) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+                      binder.s64IntegerAttr(noop_with_empty_axes,
+                                            "noop_with_empty_axes", 0))
+                    return failure();
-        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            axesList);
-        Value keepDimBool =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
-            binder.op, resultType, data, dimValueList, keepDimBool,
-            /*dtype=*/noneVal);
-        return success();
-      });
+                  return reducedSumImpl(binder, rewriter, data, resultType,
+                                        /*storeResult=*/data, keepDims,
+                                        noop_with_empty_axes, false);
+                });
   patterns.onOp(
       "ReduceMean", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 69537798f124..54687aff33f5 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2384,8 +2384,6 @@
     "RandModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceL1
-    "ReduceL1NormModule_basic",
-    "ReduceL1NormWithDTypeModule_basic",
     "ReduceL1NormComplexModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceL2
@@ -2529,6 +2527,13 @@
 }
 
+if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
+    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
+        # ERROR: Found dtype (torch.float64) but expected (torch.float32)
+        "ReduceL1NormWithDTypeModule_basic",
+    }
+
+
 ONNX_CRASHING_SET = {
     "FakeQuantizePerTensorAffineModule_basic",
     "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index a46654e45a8d..0fdecd68481e 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -863,6 +863,59 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens
   return %0 : !torch.vtensor<[4],i1>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example
+func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_l1_keep_dims_example
+func.func @test_reduce_l1_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  // CHECK: return %[[SUM]] : !torch.vtensor<[3,2,1],f32>
+  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_l1_do_not_keepdims_example
+func.func @test_reduce_l1_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
+  // CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceL1"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
 func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT0:.+]] = torch.constant.int 0