Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add onnx GridSample support for border padding mode #3819

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

std::string padding;
int64_t paddingModeInt;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
"padding_mode bind failure");
if (padding != "zeros")
if (padding == "zeros") {
paddingModeInt = 0;
} else if (padding == "border") {
paddingModeInt = 1;
} else {
return rewriter.notifyMatchFailure(
binder.op, "currently only padding_mode : zeros supported");
binder.op,
"currently only padding_mode : zeros and border supported");
}
int64_t align;
if (binder.s64IntegerAttr(align, "align_corners", 0))
return rewriter.notifyMatchFailure(binder.op,
Expand All @@ -156,8 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));

Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
binder.getLoc(), paddingModeInt);

bool alignMode = align;
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
Expand Down
28 changes: 26 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2568,10 +2568,30 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
return res;
};

auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x,
Value SizeSubOne) -> Value {
Value xMaxZero = b.create<arith::MaximumFOp>(loc, x, zeroFloat);
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
};

auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
Value x, Value SizeSubOne) -> Value {
// Border
if (paddingMode == 1) {
return lambdaBorder(b, loc, x, SizeSubOne);
}

return x;
};

auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();

int64_t paddingModeInt;
matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt));

SmallVector<Value> dynamicSizes{};
if (resultType.isDynamicDim(0))
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
Expand Down Expand Up @@ -2599,10 +2619,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
Value result0 =
Value unnorm0 =
b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
Value result1 =
Value unnorm1 =
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
Value result0 =
lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d);
Value result1 =
lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d);
Value checkLowerBound0 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
Value checkLowerBound1 = b.create<arith::CmpFOp>(
Expand Down
12 changes: 12 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,18 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t

// -----

// CHECK-LABEL: @test_grid_sampler03
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[B0:.*]] = torch.constant.bool true
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT1]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "border"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_oldest_pad
func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
// CHECK: %[[int0:.*]] = torch.constant.int 0
Expand Down
12 changes: 12 additions & 0 deletions test/Conversion/TorchToLinalg/gridsampler.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte
// -----

// CHECK-LABEL: func @grid_sampler4
// CHECK: #map
// CHECK-DAG: %[[Y49:.*]] = arith.maximumf %[[Y47:.*]], %[[CST0:.*]] : f32
// CHECK-DAG: %[[Y50:.*]] = arith.minimumf %[[Y49:.*]], %[[Y22:.*]] : f32
// CHECK-DAG: %[[Y51:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64
// CHECK-DAG: %[[Y53:.*]] = arith.select %[[Y52:.*]], %[[Y47:.*]], %[[Y50:.*]] : f32
// CHECK-DAG: %[[Y54:.*]] = arith.maximumf %[[Y48:.*]], %[[CST0:.*]] : f32
// CHECK-DAG: %[[Y55:.*]] = arith.minimumf %[[Y54:.*]], %[[Y23:.*]] : f32
// CHECK-DAG: %[[Y56:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64
// CHECK-DAG: linalg.yield %[[Y60:.*]] : f32
// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32>
func.func @grid_sampler4(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%false = torch.constant.bool 1
%int0 = torch.constant.int 0
Expand Down