From 4b17fb5fdd125bdc961ceab8e8ef608d8d3ad34f Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 3 Oct 2024 12:00:34 +0000 Subject: [PATCH 1/4] Fix AtenReflectionPad2dOp conversion to not assert when dimensions unknown --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5542e0fc642f..80fd9703f0b1 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -404,14 +404,18 @@ class ConvertAtenReflectionPad2dOp Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; - assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && - "Left padding too large"); - assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && - "Right padding too large"); - assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && - "Top padding too large"); - assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && - "Bottom padding too large"); + if (inputType.getShape()[hDim] != kUnknownSize) { + assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && + "Left padding too large"); + assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && + "Right padding too large"); + } + if (inputType.getShape()[vDim] != kUnknownSize) { + assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && + "Top padding too large"); + assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && + "Bottom padding too large"); + } Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); From 712f6e3473e272e55c4137273db797806ef38a7d Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:45:32 +0000 Subject: [PATCH 2/4] Use correct constant for dynamic dim --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 80fd9703f0b1..c1c2400475bc 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -404,13 +404,13 @@ class ConvertAtenReflectionPad2dOp Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; - if (inputType.getShape()[hDim] != kUnknownSize) { + if (inputType.getShape()[hDim] != ShapedType::kDynamic) { assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && "Left padding too large"); assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && "Right padding too large"); } - if (inputType.getShape()[vDim] != kUnknownSize) { + if (inputType.getShape()[vDim] != ShapedType::kDynamic) { assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && "Top padding too large"); assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && From a2973b0f2fddc21df3454010c6d66f1de9f954c1 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:54:03 +0000 Subject: [PATCH 3/4] Add ReflectionPad2d dynamic test --- .../torch_mlir_e2e_test/test_suite/padding.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index a97d7f09eda6..64d6e46a2a72 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -36,6 +36,29 @@ def ReflectionPad2dModule_basic(module, tu: TestUtils): # ============================================================================== +class ReflectionPad2dDynamicSizesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dDynamicSizesModule()) +def ReflectionPad2dDynamicSizesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, low=-1)) + + +# ============================================================================== + + class ReflectionPad2dModuleTop(torch.nn.Module): def __init__(self): super().__init__() From e2b26e5b74babe14e4fe8796cec3edb9eaf05853 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:54:26 +0000 Subject: [PATCH 4/4] Add dynamic pad asserts --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 05fd12f16f17..bd0215d302e7 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -403,21 +403,40 @@ class ConvertAtenReflectionPad2dOp int64_t vDim = numDims - 2; Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; + Type indexType = rewriter.getIndexType(); + + auto leftPadAssertMsg = "Left padding too large"; + auto rightPadAssertMsg = "Right padding too large"; + auto topPadAssertMsg = "Top padding too large"; + auto bottomPadAssertMsg = "Bottom padding too large"; + + auto addPadDynAssert = [&](int64_t pad, Value dimSize, + const llvm::Twine &msg) { + Value padValue = getConstant(rewriter, loc, pad, indexType); + Value pred = rewriter.create( + loc, arith::CmpIPredicate::slt, padValue, dimSize); + rewriter.create(loc, pred, rewriter.getStringAttr(msg)); + }; if (inputType.getShape()[hDim] != ShapedType::kDynamic) { assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && - "Left padding too large"); + leftPadAssertMsg); assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && - "Right padding too large"); + rightPadAssertMsg); + } else { + addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg); + addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg); } if (inputType.getShape()[vDim] != ShapedType::kDynamic) { assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && - "Top padding too large"); + topPadAssertMsg); assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && - "Bottom padding too large"); + bottomPadAssertMsg); + } else { + addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg); + addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg); } - Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType);