Skip to content

Commit

Permalink
Generalize max_unpool lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Oct 8, 2024
1 parent d49eabb commit dba2946
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 44 deletions.
25 changes: 0 additions & 25 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7159,31 +7159,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}

def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
5 changes: 0 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
SmallVector<int64_t> resultShape(resultType.getSizes());
Value resultShapeList =
createConstantIntList(binder, rewriter, resultShape);
if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
binder.op, resultType, data, indices, resultShapeList);
return success();
}

SmallVector<int64_t> padding, strides;
if (binder.s64IntegerArrayAttr(padding, "pads", {}))
Expand Down
17 changes: 9 additions & 8 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,21 +611,22 @@ class ConvertAtenMaxUnpool3dOp final
Value self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());

ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
size_t spatial = selfType.getRank() - 2;
ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
if (ShapedType::isDynamicShape(inputSize))
return rewriter.notifyMatchFailure(op,
"input type must be of static shape");

Value indices = adaptor.getIndices();
auto indicesType = cast<RankedTensorType>(indices.getType());
if (inputSize != indicesType.getShape().take_back(3))
if (inputSize != indicesType.getShape().take_back(spatial))
return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");

auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
if (!resType)
return rewriter.notifyMatchFailure(op, "invalid result type");

ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
if (ShapedType::isDynamicShape(inferredOutSize))
return rewriter.notifyMatchFailure(op,
"output type must be of static shape");
Expand All @@ -636,7 +637,7 @@ class ConvertAtenMaxUnpool3dOp final
return rewriter.notifyMatchFailure(op,
"only support constant int output");

if (inferredOutSize != ArrayRef(output))
if (inferredOutSize != ArrayRef(output).take_back(spatial))
return rewriter.notifyMatchFailure(op, "Invalid output size");
}
SmallVector<int64_t> stride;
Expand All @@ -652,12 +653,12 @@ class ConvertAtenMaxUnpool3dOp final

// TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
// (padding.size() == 6).
if (stride.size() != 3 || padding.size() != 3)
if (stride.size() != spatial || padding.size() != spatial)
return rewriter.notifyMatchFailure(
op, "stride and padding must be of size 3");

int64_t outRank = resType.getRank();
int64_t NC = outRank - 3;
int64_t NC = outRank - spatial;

for (auto &&[inDim, outDim, str, pad] :
llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
Expand Down Expand Up @@ -694,7 +695,7 @@ class ConvertAtenMaxUnpool3dOp final
// (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
// pad self and indices tensors to avoid out of bounds access.
SmallVector<int64_t> expectedInputShape =
llvm::to_vector(resType.getShape().drop_back(3));
llvm::to_vector(resType.getShape().drop_back(spatial));
for (auto &&[str, pad, resSize] :
llvm::zip_equal(stride, padding, inferredOutSize))
expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
Expand All @@ -707,7 +708,7 @@ class ConvertAtenMaxUnpool3dOp final
SmallVector<int64_t> low(outRank, 0);
SmallVector<int64_t> high(NC, 0);
for (auto &&[inpSize, outSize] : llvm::zip_equal(
inputSize, ArrayRef(expectedInputShape).take_back(3))) {
inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
high.emplace_back(outSize - inpSize);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
Expand Down
16 changes: 11 additions & 5 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1667,23 +1667,29 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc

// -----

// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape
func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT4_0:.*]] = torch.constant.int 4
// CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
// return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
%0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
return %0 : !torch.vtensor<[1,1,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape
func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[INT4:.*]] = torch.constant.int 4
Expand Down
42 changes: 42 additions & 0 deletions test/Conversion/TorchToLinalg/pooling.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,45 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
// CHECK: } -> tensor<?x?x?x?x?xf32>
return %4 : !torch.vtensor<[?,?,?,?,?],f32>
}

// -----

// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @forward_max_unpool
func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%int1 = torch.constant.int 1
%int1_0 = torch.constant.int 1
%int4 = torch.constant.int 4
%int4_1 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int0 = torch.constant.int 0
%int0_2 = torch.constant.int 0
%1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list<int>
%int2 = torch.constant.int 2
%int2_3 = torch.constant.int 2
%2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.max_unpool3d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>

// CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<1x1x2x2xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
// CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4x4xf32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor<?x?x4x4xf32>) {
// CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX:.*]] : i64 to index
// CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index
// CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index
// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index
// CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2:.*]], %[[C4:.*]] : index
// CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI:.*]], %[[INDEX3:.*]] : index
// CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST:.*]], %[[ADDI:.*]] : index
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI:.*]], %[[CURRENT_VALUE:.*]], %[[CST:.*]] : f32
// CHECK-NEXT: linalg.yield %[[SELECT:.*]] : f32
// CHECK: } -> tensor<?x?x4x4xf32>
return %3 : !torch.vtensor<[1,1,4,4],f32>
}

0 comments on commit dba2946

Please sign in to comment.