diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1f3ff7ac2346..55ddc65b3a03 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } std::string padding; + int64_t paddingModeInt; if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) return rewriter.notifyMatchFailure(binder.op, "padding_mode bind failure"); - if (padding != "zeros") + if (padding == "zeros") { + paddingModeInt = 0; + } else if (padding == "border") { + paddingModeInt = 1; + } else { return rewriter.notifyMatchFailure( - binder.op, "currently only padding_mode : zeros supported"); + binder.op, + "currently only padding_mode : zeros and border supported"); + } int64_t align; if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, @@ -156,8 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); Value paddingMode = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + binder.getLoc(), paddingModeInt); bool alignMode = align; Value alignCorners = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 35e4144f30eb..957937180236 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2568,10 +2568,30 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { return res; }; + auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x, + Value SizeSubOne) -> Value { + Value xMaxZero = b.create(loc, x, zeroFloat); + return b.create(loc, xMaxZero, SizeSubOne); + }; + + auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode, + Value x, Value SizeSubOne) -> Value { + // Border + if (paddingMode == 1) { + return lambdaBorder(b, loc, x, SizeSubOne); + } + + return x; + }; + auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); + + int64_t paddingModeInt; + matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt)); + SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) dynamicSizes.push_back(rewriter.create(loc, input, 0)); @@ -2599,10 +2619,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value gplus1 = b.create(loc, gr1, oneFloat); Value gPlusMul0 = b.create(loc, gplus0, innerDim0e); Value gPlusMul1 = b.create(loc, gplus1, innerDim1e); - Value result0 = + Value unnorm0 = b.create(loc, gPlusMul0, gr0HalfSelect); - Value result1 = + Value unnorm1 = b.create(loc, gPlusMul1, gr1HalfSelect); + Value result0 = + lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d); + Value result1 = + lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d); Value checkLowerBound0 = b.create( loc, arith::CmpFPredicate::OLT, result0, zeroFloat); Value checkLowerBound1 = b.create( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d567db79fdf8..e2dbf1c06c8c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -965,6 +965,18 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t // ----- +// CHECK-LABEL: @test_grid_sampler03 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[B0:.*]] = torch.constant.bool true +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT1]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "border"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_oldest_pad func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} { // CHECK: %[[int0:.*]] = torch.constant.int 0 diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 2a291f721fed..f56881898d67 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -96,6 +96,18 @@ func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte // ----- // CHECK-LABEL: func @grid_sampler4 +// CHECK: #map +// CHECK-DAG: %[[Y49:.*]] = arith.maximumf %[[Y47:.*]], %[[CST0:.*]] : f32 +// CHECK-DAG: %[[Y50:.*]] = arith.minimumf %[[Y49:.*]], %[[Y22:.*]] : f32 +// CHECK-DAG: %[[Y51:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64 +// CHECK-DAG: %[[Y53:.*]] = arith.select %[[Y52:.*]], %[[Y47:.*]], %[[Y50:.*]] : f32 +// CHECK-DAG: %[[Y54:.*]] = arith.maximumf %[[Y48:.*]], %[[CST0:.*]] : f32 +// CHECK-DAG: %[[Y55:.*]] = arith.minimumf %[[Y54:.*]], %[[Y23:.*]] : f32 +// CHECK-DAG: %[[Y56:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64 +// CHECK-DAG: linalg.yield %[[Y60:.*]] : f32 +// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32> func.func @grid_sampler4(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %false = torch.constant.bool 1 %int0 = torch.constant.int 0