diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index b653d95dda02..9b4333c79a70 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7159,6 +7159,33 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
+def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 3cb219b57938..f93e5c5d178e 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3464,6 +3464,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesList = createConstantIntList(binder, rewriter, strides);
 
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
+              binder.op, resultType, data, indices, resultShapeList,
+              stridesList, paddingList);
+          return success();
+        }
         rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
             binder.op, resultType, data, indices, resultShapeList,
             stridesList, paddingList);
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index b918615f8634..f17bec3aa410 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -596,37 +596,51 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What worse, without knowing kernel size we cannot even reliably detect such
 // cases and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+
+template <> struct DimensionTraits<AtenMaxUnpool2dOp> {
+  static constexpr int64_t Dim = 2;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <> struct DimensionTraits<AtenMaxUnpool3dOp> {
+  static constexpr int64_t Dim = 3;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <typename OpTy>
+class ConvertAtenMaxUnpoolOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  LogicalResult createUnpoolOp(OpTy &op, typename OpTy::Adaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
     Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    size_t spatial = selfType.getRank() - 2;
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(spatial))
+    if (inputSize != indicesType.getShape().take_back(Dim))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -637,7 +651,7 @@ class ConvertAtenMaxUnpool3dOp final
         return rewriter.notifyMatchFailure(op,
                                            "only support constant int output");
 
-      if (inferredOutSize != ArrayRef(output).take_back(spatial))
+      if (inferredOutSize != ArrayRef(output).take_back(Dim))
         return rewriter.notifyMatchFailure(op, "Invalid output size");
     }
     SmallVector<int64_t> stride;
@@ -653,12 +667,12 @@ class ConvertAtenMaxUnpool3dOp final
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
    // (padding.size() == 6).
-    if (stride.size() != spatial || padding.size() != spatial)
+    if (stride.size() != Dim || padding.size() != Dim)
       return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+          op, "stride and padding must be of size Dim");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - spatial;
+    int64_t NC = outRank - Dim;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -695,7 +709,7 @@ class ConvertAtenMaxUnpool3dOp final
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(spatial));
+        llvm::to_vector(resType.getShape().drop_back(Dim));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -708,7 +722,7 @@ class ConvertAtenMaxUnpool3dOp final
 
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(Dim))) {
       high.emplace_back(outSize - inpSize);
     }
 
@@ -827,6 +841,13 @@ class ConvertAtenMaxUnpool3dOp final
     rewriter.replaceOp(op, result);
     return success();
   }
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return createUnpoolOp(op, adaptor, rewriter);
+  }
 };
 } // namespace
 
@@ -1527,8 +1548,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
+  target.addIllegalOp<AtenMaxUnpool2dOp>();
   target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool2dOp>>(typeConverter,
+                                                          context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool3dOp>>(typeConverter,
+                                                          context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 2b7db059bb42..876140afe9dd 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1056,14 +1056,17 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
     maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
     return maxpool3d, indices
 
+def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+    assert (len(self) == 4), "Input must be of rank 4"
+    assert (len(output_size) == 2), "output_size must have 2 elements"
+    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
+    return [self[0], self[1], output_size[0], output_size[1]]
+
 def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
-    assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
+    assert (len(self) == 5), "Input must be of rank 5"
     assert (len(output_size) == 3), "output_size must have 3 elements"
     assert (len(self) == len(indices)), "Input and indices must be of the same rank"
-    if len(self) == 5:
-        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
-    else:
-        return [self[0], output_size[0], output_size[1], output_size[2]]
+    return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
 
 def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
@@ -3202,6 +3205,10 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
+def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 8812a713faa5..9c22a3539ee2 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -622,6 +622,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 68119eb57841..eb19e18ac4cd 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1680,7 +1680,7 @@ func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4],f32>
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index bd696a67a1a3..663b077c4b3b 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -100,8 +100,8 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
 // CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func @forward_max_unpool
-func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func @forward_max_unpool2d
+func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %int1 = torch.constant.int 1
   %int1_0 = torch.constant.int 1
   %int4 = torch.constant.int 4
@@ -113,7 +113,7 @@ func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torc
   %int2 = torch.constant.int 2
   %int2_3 = torch.constant.int 2
   %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %3 = torch.aten.max_unpool3d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  %3 = torch.aten.max_unpool2d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
   // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
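
For reference (not part of the patch): a small eager-mode PyTorch sketch of the max_unpool2d behavior that the new op and its lowerings model, including the output-size ambiguity described in the Pooling.cpp comment above. Shapes and values are illustrative only.

import torch
import torch.nn.functional as F

# Pool a 1x1x4x4 input; indices record which element of each 2x2 window won.
x = torch.arange(16.0).reshape(1, 1, 4, 4)
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# Unpooling scatters the pooled values back to their recorded positions and
# fills everything else with zeros; with kernel_size=2, stride=2 the default
# output is 4x4.
y = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
print(y.shape)  # torch.Size([1, 1, 4, 4])

# The same 2x2 pooled tensor could also have come from a 5x5 input pooled with
# kernel_size=3, stride=2, so the output size cannot be recovered from stride
# and padding alone; an explicit output_size resolves the ambiguity.
y5 = F.max_unpool2d(pooled, indices, kernel_size=3, stride=2, output_size=(5, 5))
print(y5.shape)  # torch.Size([1, 1, 5, 5])

The generated aten.max_unpool2d op above carries output_size, stride, and padding but no kernel_size, which is why the TorchToLinalg conversion must reject configurations whose window geometry it cannot infer.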