diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index cf41bbcd711b..363b846aae04 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -116,6 +117,34 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, else division = b.createOrFold(loc, dividend, strideInt); Value out = b.createOrFold(loc, division, c1); + + if (ceilMode) { + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), paddingInt); + + auto reduceOutputDim = + b.createOrFold(loc, arith::CmpIPredicate::uge, + outMinusOneTimesStride, inAddLeftPadding); + + // Emit 'then' region of 'scf.if' + auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { + opBuilder.create(loc, division); + }; + + // Emit 'else' region of 'scf.if' + auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { + opBuilder.create(loc, out); + }; + + // Emit 'scf.if' op + auto ifOp = b.create(loc, reduceOutputDim, emitThenRegion, + emitElseRegion); + + return castIntToIndex(b, loc, ifOp.getResult(0)); + } + return castIntToIndex(b, loc, out); } @@ -527,8 +556,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, Type elementType) { auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { - Value elem = - convertScalarToDtype(builder, loc, payloadArgs[0], elementType); + Value elem = mlir::torch::Torch::convertScalarToDtype( + builder, loc, payloadArgs[0], elementType); builder.create(loc, elem); }; return torch_to_linalg::createElementwiseLinalgGeneric( diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e75d358b068d..824f1de07195 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5252,9 +5252,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - if (ceilMode && (dimSize % stride != 0)) - return dimSize / stride + 2; - return dimSize / stride + 1; + int64_t outputDim = dimSize / stride + 1; + if (ceilMode && (dimSize % stride != 0) && + (outputDim * stride < inputDim + padBefore)) + outputDim++; + return outputDim; } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e0011b9a347e..15f73f87eec3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -763,6 +763,7 @@ "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", @@ -2261,6 +2262,7 @@ "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dStaticModule_basic", "MeanModule_basic", "MmDagModule_basic", @@ -3000,6 +3002,7 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", @@ -4516,6 +4519,7 @@ "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 84e0e2eb9cf5..e2eaa4cfd0fe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp2d = torch.nn.MaxPool2d( + kernel_size=6, + stride=6, + padding=3, + dilation=1, + ceil_mode=True, + ) + + @export + @annotate_args( + [ + None, + ([2, 6, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp2d(x) + + +@register_test_case( + module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule() +) +def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0)) + + # ==============================================================================