diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index a18c0bae01fc..bd0215d302e7 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -403,17 +403,40 @@ class ConvertAtenReflectionPad2dOp
     int64_t vDim = numDims - 2;
     Value hDimSize = inputShape[hDim];
     Value vDimSize = inputShape[vDim];
+    Type indexType = rewriter.getIndexType();

-    assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
-           "Left padding too large");
-    assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
-           "Right padding too large");
-    assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
-           "Top padding too large");
-    assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
-           "Bottom padding too large");
+    auto leftPadAssertMsg = "Left padding too large";
+    auto rightPadAssertMsg = "Right padding too large";
+    auto topPadAssertMsg = "Top padding too large";
+    auto bottomPadAssertMsg = "Bottom padding too large";
+
+    auto addPadDynAssert = [&](int64_t pad, Value dimSize,
+                               const llvm::Twine &msg) {
+      Value padValue = getConstant(rewriter, loc, pad, indexType);
+      Value pred = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, padValue, dimSize);
+      rewriter.create<cf::AssertOp>(loc, pred, rewriter.getStringAttr(msg));
+    };
+
+    if (inputType.getShape()[hDim] != ShapedType::kDynamic) {
+      assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
+             leftPadAssertMsg);
+      assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
+             rightPadAssertMsg);
+    } else {
+      addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg);
+      addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg);
+    }
+    if (inputType.getShape()[vDim] != ShapedType::kDynamic) {
+      assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
+             topPadAssertMsg);
+      assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
+             bottomPadAssertMsg);
+    } else {
+      addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg);
+      addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg);
+    }

-    Type indexType = rewriter.getIndexType();
     Value zero = getConstant(rewriter, loc, 0, indexType);
     Value one = getConstant(rewriter, loc, 1, indexType);

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
index a97d7f09eda6..64d6e46a2a72 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
@@ -36,6 +36,29 @@ def ReflectionPad2dModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ReflectionPad2dDynamicSizesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dDynamicSizesModule())
+def ReflectionPad2dDynamicSizesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 20, 20, low=-1))
+
+
+# ==============================================================================
+
+
 class ReflectionPad2dModuleTop(torch.nn.Module):
     def __init__(self):
         super().__init__()
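
Note (reviewer illustration, not part of the patch): the lowering keeps the compile-time `assert`s when the padded dimension is statically known, and otherwise emits a runtime `cf.assert` comparing the pad amount against the dynamic dimension size. The bound being enforced is PyTorch's own rule that reflection padding must be strictly smaller than the corresponding input dimension. A minimal eager-mode sketch of that constraint, in plain PyTorch and independent of torch-mlir:

```python
import torch

x = torch.rand(1, 20, 20)

# pad 10 < dim 20 on both H and W, so this is legal; each side grows by 10.
y = torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))
print(y.shape)  # torch.Size([1, 40, 40])

# pad 20 >= dim 20 violates the bound. Eager PyTorch raises a RuntimeError;
# with this patch, a dynamically shaped input trips the emitted cf.assert at
# runtime instead of proceeding with out-of-range reflection indices.
try:
    torch.ops.aten.reflection_pad2d(x, (20, 20, 20, 20))
except RuntimeError as e:
    print("rejected:", e)
```

The new `ReflectionPad2dDynamicSizesModule` test exercises exactly the dynamic path: H and W are annotated as `-1`, so the static asserts cannot fire and the `cf.assert` ops are emitted instead.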