From a5db8a14e0a51bd223af1ba2b7994dd7a858ee2c Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:25:49 +0000 Subject: [PATCH 01/11] Add rfft and its conversion to linalg --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++ lib/Conversion/TorchToLinalg/Linear.cpp | 154 ++++++++++++++++++ .../build_tools/torch_ods_gen.py | 1 + 3 files changed, 181 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 44bf8ab2e0d4..2d03438967db 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12980,6 +12980,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } +def Torch_AtenFftRfftOp : Torch_Op<"aten.fft_rfft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$n, + Torch_IntType:$dim, + AnyTorchOptionalStringType:$norm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFftRfftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenFftRfftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index a4962d12abdc..da998861d040 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -11,6 +11,7 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/Matchers.h" @@ -1373,6 +1374,157 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { }; } // namespace +namespace { + +/// From +/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +/// +/// Creates coefficients based on DFT definition, see +/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. 
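+/// Entry (i, j) of the returned matrix is cos(2*pi*i*j/N) for the real
+/// part and -sin(2*pi*i*j/N) for the imaginary part, where N is the FFT
+/// length (the first matrix dimension) and j runs over the N/2 + 1 output
+/// frequencies of the real-to-complex FFT.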
+Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType, + bool isRealPart) { + // scale = 2 * pi / N + double scale = 2 * M_PI / matrixType.getDimSize(0); + + SmallVector values; + assert(matrixType.getRank() == 2 && "expected 2D matrix"); + for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { + for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { + double v = scale * i * j; + if (isRealPart) { + v = cos(v); + } else { + v = -sin(v); + } + values.push_back(b.getF32FloatAttr(v)); + } + } + return b.create( + loc, matrixType, DenseFPElementsAttr::get(matrixType, values)); +} + +/// From +/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +Value createLinalgMatmulOnTensors(OpBuilder b, Location loc, + RankedTensorType resultType, Value lhs, + Value rhs) { + Value zero = b.create( + loc, b.getZeroAttr(resultType.getElementType())); + Value emptyTensor = b.create( + loc, resultType.getShape(), resultType.getElementType(), + /*dyn_size=*/ValueRange{}); + Value zeroTensor = + b.create(loc, zero, emptyTensor).getResult(0); + + switch (llvm::cast(lhs.getType()).getRank()) { + case 1: + return b + .create(loc, TypeRange{resultType}, + ValueRange{lhs, rhs}, ValueRange{zeroTensor}) + .getResult(0); + case 2: + return b + .create(loc, TypeRange{resultType}, + ValueRange{lhs, rhs}, ValueRange{zeroTensor}) + .getResult(0); + default: + assert(false && "unhandled matmul type"); + return Value(); + } +} + +struct ConvertAtenFftRfftOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = adaptor.getSelf(); + + int64_t dim; + auto dimVal = op.getDim(); + if (isa(dimVal.getType())) { + dim = -1; + } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constant"); + } + + if (!isa(op.getN().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); + } + + if (!isa(op.getNorm().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); + } + + auto inputType = dyn_cast(adaptor.getSelf().getType()); + if (!inputType.hasStaticShape() || inputType.getRank() > 2) { + return rewriter.notifyMatchFailure( + op, "unsupported: only static 1D or 2D FFT is supported"); + } + + const ArrayRef inputShape = inputType.getShape(); + dim += dim < 0 ? 
inputShape.size() : 0; + + const int64_t fftLength = inputShape[dim]; + const int64_t rank = inputType.getRank(); + const int64_t lastDim = rank - 1; + const int64_t outputFftDim = fftLength / 2 + 1; + const bool needTranspose = dim != lastDim; + + // Transpose if FFT dimension is not the last one + llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); + std::swap(perms[dim], perms[lastDim]); + if (needTranspose) { + self = transposeValue(loc, self, perms, rewriter); + } + + auto matrixType = RankedTensorType::get({fftLength, outputFftDim}, + inputType.getElementType()); + + RankedTensorType newResultType = llvm::cast( + getTypeConverter()->convertType(op.getType())); + + auto componentsType = RankedTensorType::get(newResultType.getShape(), + inputType.getElementType()); + + Value realMatrix = + getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true); + Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType, + self, realMatrix); + + Value imagMatrix = + getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false); + Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType, + self, imagMatrix); + + // Pack components into a complex tensor + Type elementType = newResultType.getElementType(); + auto toComplexBody = [&](OpBuilder &b, Location loc, + ValueRange payloadArgs) { + Value realElem = payloadArgs[0]; + Value imagElem = payloadArgs[1]; + Value complexElem = + b.create(loc, elementType, realElem, imagElem); + b.create(loc, complexElem); + }; + Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, {real, imag}, elementType, toComplexBody); + + // Transpose back + if (needTranspose) { + complexRes = transposeValue(loc, complexRes, perms, rewriter); + } + + rewriter.replaceOpWithNewOp(op, newResultType, complexRes); + return success(); + } +}; + +} // namespace + void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1387,4 +1539,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ea5070a8c0bb..97f78199ce90 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -951,6 +951,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)" ) emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fft_ifft : (Tensor, int?, int, str?) 
-> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit( From 6c9da80ef75e3ff537f2719099436f81e45a2407 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:04:04 +0000 Subject: [PATCH 02/11] Add rfft to abstract interp lib --- .../Transforms/AbstractInterpLibrary.cpp | 82 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 29 +++++++ 2 files changed, 111 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 559726f20659..25761f04b882 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10689,6 +10689,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fft_rfft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Expected dim in [-rank, rank-1]\"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.aten.add.int %arg2, %10 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %11 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %2 = torch.aten.ge.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.aten.lt.int %1, %10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %4, %10 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.add.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten._set_item.t %4, %1, %8 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.list {\n" " %str = torch.constant.str 
\"AssertionError: Expected hop_length to be greater than 0\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n" @@ -12766,6 +12810,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_rfft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int7 = torch.constant.int 7\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2b7db059bb42..f883ade14fae 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2192,6 +2192,18 @@ def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]: def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self +@check_shape_function([ + Invocation(TensorOfShape(3, 9, 5), None, -2, None) # Second-last dim +]) +def aten〇fft_rfft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: + dim = (dim + len(self)) if dim < 0 else dim + assert dim >= 0 and dim < len(self), "Expected dim in [-rank, rank-1]" + out: List[int] = [] + for s in self: + out.append(s) + out[dim] = self[dim] // 2 + 1 + return out + @check_shape_function([ Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an 
explicit 1-D window. ]) @@ -3732,6 +3744,23 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = else: assert False, "Unsupported dtype" +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex32, torch.complex64, torch.complex128, torch.bfloat16})) +def aten〇fft_rfft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + + + @check_dtype_function([ Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32 Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64 From 643c6d52071083024aca588bfa2767e6be13f474 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 7 Nov 2024 13:23:11 +0000 Subject: [PATCH 03/11] Fix wrong component shape when transposing --- lib/Conversion/TorchToLinalg/Linear.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index da998861d040..f6dd45df3b0e 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1459,7 +1459,12 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern { return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); } - auto inputType = dyn_cast(adaptor.getSelf().getType()); + RankedTensorType inputType = + cast(adaptor.getSelf().getType()); + if (!inputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "unsupported: only ranked tensors are supported"); + } if (!inputType.hasStaticShape() || inputType.getRank() > 2) { return rewriter.notifyMatchFailure( op, "unsupported: only static 1D or 2D FFT is supported"); @@ -1474,21 +1479,25 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern { const int64_t outputFftDim = fftLength / 2 + 1; const bool needTranspose = dim != lastDim; + RankedTensorType newResultType = llvm::cast( + getTypeConverter()->convertType(op.getType())); + llvm::SmallVector componentShape(newResultType.getShape()); + // Transpose if FFT dimension is not the last one llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); std::swap(perms[dim], perms[lastDim]); if (needTranspose) { self = transposeValue(loc, self, perms, rewriter); + for (size_t i = 0; i < componentShape.size(); i++) { + componentShape[i] = newResultType.getShape()[perms[i]]; + } } - auto matrixType = RankedTensorType::get({fftLength, outputFftDim}, - inputType.getElementType()); - - RankedTensorType newResultType = llvm::cast( - getTypeConverter()->convertType(op.getType())); + RankedTensorType matrixType = RankedTensorType::get( + {fftLength, outputFftDim}, inputType.getElementType()); - auto componentsType = RankedTensorType::get(newResultType.getShape(), - inputType.getElementType()); + RankedTensorType componentsType = + RankedTensorType::get(componentShape, inputType.getElementType()); Value realMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true); From c073ea23fee6886eafee1f69fca967b6beb7403a Mon Sep 
17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 7 Nov 2024 13:23:54 +0000 Subject: [PATCH 04/11] Add tests --- .../test_suite/spectral.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py index 8e259fbe0c2a..57a7270f9d09 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py @@ -51,3 +51,43 @@ def forward(self): @register_test_case(module_factory=lambda: AtenHannWindowPeriodicTrueModule()) def AtenHannWindowPeriodicTrueModule_basic(module, tu: TestUtils): module.forward() + + +# ============================================================================== + + +class AtenFftRfft2DLastDim(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([16, 9], torch.float32, True), + ] + ) + def forward(self, input): + return torch.fft.rfft(input, dim=-1) + + +@register_test_case(module_factory=lambda: AtenFftRfft2DLastDim()) +def AtenFftRfft2DLastDim_basic(module, tu: TestUtils): + module.forward(tu.rand(16, 9)) + + +# ============================================================================== + + +class AtenFftRfft2DMiddleDim(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([36, 10], torch.float32, True), + ] + ) + def forward(self, input): + return torch.fft.rfft(input, dim=0) + + +@register_test_case(module_factory=lambda: AtenFftRfft2DMiddleDim()) +def AtenFftRfft2DMiddleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(36, 10)) From ad189a97bbad69fef04eb1b8e6d413d129545731 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:12:54 +0000 Subject: [PATCH 05/11] Add unit tests --- test/Conversion/TorchToLinalg/spectral.mlir | 62 +++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 test/Conversion/TorchToLinalg/spectral.mlir diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir new file mode 100644 index 000000000000..af0e4f4f6299 --- /dev/null +++ b/test/Conversion/TorchToLinalg/spectral.mlir @@ -0,0 +1,62 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( +// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32> +// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32> +// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> +// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32> +// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> +// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], 
%[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> +// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex> +// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex): +// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex +// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex +// CHECK: } -> tensor<16x5xcomplex> +// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex> + +func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { + %int-1 = torch.constant.int -1 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> + return %out : !torch.vtensor<[16,5],complex> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( +// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> +// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> +// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32> +// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0] +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32> +// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> +// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> +// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32> +// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> +// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> +// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex> +// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex): +// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex +// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex +// CHECK: } -> tensor<23x19xcomplex> +// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex> +// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex>) permutation = [1, 0] +// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : 
tensor<19x23xcomplex> -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex> +func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> + return %out : !torch.vtensor<[19,23],complex> +} From 9194e6e4bd628e03bd3d74fb007565132b15b974 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:30:41 +0000 Subject: [PATCH 06/11] Add tests to ONNX xfail set --- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 052eceb5ac4a..f44031b377cb 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2675,6 +2675,8 @@ "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", From 721c8f1c00c4421882ad27464b67aeb7fad1a146 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Tue, 12 Nov 2024 12:52:42 +0000 Subject: [PATCH 07/11] Change to decompose implementation --- lib/Conversion/TorchToLinalg/Linear.cpp | 162 --------------- .../Torch/Transforms/DecomposeComplexOps.cpp | 191 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + test/Conversion/TorchToLinalg/spectral.mlir | 62 ------ test/Dialect/Torch/decompose-complex-ops.mlir | 63 ++++++ 5 files changed, 255 insertions(+), 224 deletions(-) delete mode 100644 test/Conversion/TorchToLinalg/spectral.mlir diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index f6dd45df3b0e..557e6cd430ee 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1374,166 +1374,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { }; } // namespace -namespace { - -/// From -/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp -/// -/// Creates coefficients based on DFT definition, see -/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. 
-Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType, - bool isRealPart) { - // scale = 2 * pi / N - double scale = 2 * M_PI / matrixType.getDimSize(0); - - SmallVector values; - assert(matrixType.getRank() == 2 && "expected 2D matrix"); - for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { - for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { - double v = scale * i * j; - if (isRealPart) { - v = cos(v); - } else { - v = -sin(v); - } - values.push_back(b.getF32FloatAttr(v)); - } - } - return b.create( - loc, matrixType, DenseFPElementsAttr::get(matrixType, values)); -} - -/// From -/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp -Value createLinalgMatmulOnTensors(OpBuilder b, Location loc, - RankedTensorType resultType, Value lhs, - Value rhs) { - Value zero = b.create( - loc, b.getZeroAttr(resultType.getElementType())); - Value emptyTensor = b.create( - loc, resultType.getShape(), resultType.getElementType(), - /*dyn_size=*/ValueRange{}); - Value zeroTensor = - b.create(loc, zero, emptyTensor).getResult(0); - - switch (llvm::cast(lhs.getType()).getRank()) { - case 1: - return b - .create(loc, TypeRange{resultType}, - ValueRange{lhs, rhs}, ValueRange{zeroTensor}) - .getResult(0); - case 2: - return b - .create(loc, TypeRange{resultType}, - ValueRange{lhs, rhs}, ValueRange{zeroTensor}) - .getResult(0); - default: - assert(false && "unhandled matmul type"); - return Value(); - } -} - -struct ConvertAtenFftRfftOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = adaptor.getSelf(); - - int64_t dim; - auto dimVal = op.getDim(); - if (isa(dimVal.getType())) { - dim = -1; - } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: requires dim to be constant"); - } - - if (!isa(op.getN().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); - } - - if (!isa(op.getNorm().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); - } - - RankedTensorType inputType = - cast(adaptor.getSelf().getType()); - if (!inputType.hasRank()) { - return rewriter.notifyMatchFailure( - op, "unsupported: only ranked tensors are supported"); - } - if (!inputType.hasStaticShape() || inputType.getRank() > 2) { - return rewriter.notifyMatchFailure( - op, "unsupported: only static 1D or 2D FFT is supported"); - } - - const ArrayRef inputShape = inputType.getShape(); - dim += dim < 0 ? 
inputShape.size() : 0; - - const int64_t fftLength = inputShape[dim]; - const int64_t rank = inputType.getRank(); - const int64_t lastDim = rank - 1; - const int64_t outputFftDim = fftLength / 2 + 1; - const bool needTranspose = dim != lastDim; - - RankedTensorType newResultType = llvm::cast( - getTypeConverter()->convertType(op.getType())); - llvm::SmallVector componentShape(newResultType.getShape()); - - // Transpose if FFT dimension is not the last one - llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); - std::swap(perms[dim], perms[lastDim]); - if (needTranspose) { - self = transposeValue(loc, self, perms, rewriter); - for (size_t i = 0; i < componentShape.size(); i++) { - componentShape[i] = newResultType.getShape()[perms[i]]; - } - } - - RankedTensorType matrixType = RankedTensorType::get( - {fftLength, outputFftDim}, inputType.getElementType()); - - RankedTensorType componentsType = - RankedTensorType::get(componentShape, inputType.getElementType()); - - Value realMatrix = - getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true); - Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType, - self, realMatrix); - - Value imagMatrix = - getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false); - Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType, - self, imagMatrix); - - // Pack components into a complex tensor - Type elementType = newResultType.getElementType(); - auto toComplexBody = [&](OpBuilder &b, Location loc, - ValueRange payloadArgs) { - Value realElem = payloadArgs[0]; - Value imagElem = payloadArgs[1]; - Value complexElem = - b.create(loc, elementType, realElem, imagElem); - b.create(loc, complexElem); - }; - Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric( - rewriter, loc, {real, imag}, elementType, toComplexBody); - - // Transpose back - if (needTranspose) { - complexRes = transposeValue(loc, complexRes, perms, rewriter); - } - - rewriter.replaceOpWithNewOp(op, newResultType, complexRes); - return success(); - } -}; - -} // namespace - void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1548,6 +1388,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 29c176f96afd..888b13f1c3d6 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9020,6 +9020,196 @@ class DecomposeAtenTopkOp : public OpRewritePattern { }; } // namespace +namespace { + +/// Creates coefficients based on DFT definition, see +/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. 
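+/// The decomposition multiplies the (unsqueezed) input signal by this
+/// N x (N/2 + 1) coefficient matrix twice, once with cos coefficients for
+/// the real part and once with -sin coefficients for the imaginary part,
+/// then stacks the two results and views them as a single complex tensor.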
+Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, + ValueTensorType matrixType, bool isRealPart) { + // scale = 2 * pi / N + double scale = 2 * M_PI / matrixType.getSizes()[0]; + + SmallVector values; + assert(matrixType.getSizes().size() == 2 && "expected 2D matrix"); + for (auto i : llvm::seq(0, matrixType.getSizes()[0])) { + for (auto j : llvm::seq(0, matrixType.getSizes()[1])) { + double v = scale * i * j; + if (isRealPart) { + v = cos(v); + } else { + v = -sin(v); + } + values.push_back(rewriter.getF32FloatAttr(v)); + } + } + + return rewriter.create( + loc, matrixType, + DenseElementsAttr::get(matrixType.toBuiltinTensor(), + ArrayRef(values))); +} + +Value createBatchMatmul(PatternRewriter &rewriter, Location loc, Value lhs, + Value rhs) { + + BaseTensorType lhsType = cast(lhs.getType()); + assert(lhsType && lhsType.hasSizes()); + const ArrayRef lhsShape = lhsType.getSizes(); + assert(lhsShape.size() >= 2); + BaseTensorType rhsType = cast(rhs.getType()); + assert(rhsType && rhsType.hasSizes()); + const ArrayRef rhsShape = rhsType.getSizes(); + assert(rhsShape.size() >= 2); + assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]); + + SmallVector resShape(lhsShape); + resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1]; + + Type dtype = lhsType.getOptionalDtype(); + + ValueTensorType resType = + ValueTensorType::get(rewriter.getContext(), resShape, dtype); + return rewriter.create(loc, resType, lhs, rhs); +} + +class DecomposeAtenFftRfftOp final : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFftRfftOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + int64_t dim; + auto dimVal = op.getDim(); + if (isa(dimVal.getType())) { + dim = -1; + } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constant"); + } + + if (!isa(op.getN().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); + } + + if (!isa(op.getNorm().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); + } + + BaseTensorType inputType = cast(self.getType()); + + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "unsupported: only ranked tensors are supported"); + } + + const ArrayRef inputShape = inputType.getSizes(); + dim += dim < 0 ? 
inputShape.size() : 0; + + const int64_t fftLength = inputShape[dim]; + if (fftLength == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "unsupported: input signal length must be known"); + } + const int64_t rank = inputShape.size(); + const int64_t lastDim = rank - 1; + const int64_t outputFftDim = fftLength / 2 + 1; + const bool needTranspose = dim != lastDim; + + auto transposeValue = [](PatternRewriter &rewriter, Location loc, + Value input, int64_t dimA, int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), dimA, + dimB, transposedType))) + return failure(); + Value cstDimA = + rewriter.create(loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = + rewriter.create(loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create(loc, transposedType, + input, cstDimA, cstDimB); + return success(); + }; + + // Transpose if FFT dimension is not the last one + if (needTranspose) { + if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) { + return failure(); + } + } + + // lhs = unsqueeze(self, -2) : (D x 1 x fftLength), D = [D_1, D_2, ...] + Value unsqueezeDim = + rewriter.create(loc, rewriter.getI64IntegerAttr(-2)); + auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim); + if (failed(unsqueezed)) + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueezed tensor"); + Value lhs = *unsqueezed; + Type dtype = inputType.getOptionalDtype(); + + Value real, complex; + + for (const bool isRealPart : {true, false}) { + + // coeff : (fftLength x outputFftDim) + ValueTensorType matrixType = ValueTensorType::get( + op.getContext(), SmallVector{fftLength, outputFftDim}, + dtype); + Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType, + /*isRealPart=*/isRealPart); + + // X = matmul(lhs, coeff) : (D x 1 x outputFftDim) + Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix); + + // Y = squeeze(X, -2) : (D x outputFftDim) + auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes); + if (failed(squeezed)) + return rewriter.notifyMatchFailure(op, + "cannot generate squeezed tensor"); + + if (isRealPart) { + real = *squeezed; + } else { + complex = *squeezed; + } + } + + // Pack components into a complex tensor + BaseTensorType realType = cast(real.getType()); + SmallVector stackSizes(realType.getSizes()); + stackSizes.push_back(2); + Value sequence = rewriter.create( + loc, ListType::get(op.getContext(), realType), + ValueRange{real, complex}); + Type stackType = realType.getWithSizesAndDtype(stackSizes, dtype); + Value cstMinusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value stack = + rewriter.create(loc, stackType, sequence, cstMinusOne); + Type complexResType = ValueTensorType::get( + op.getContext(), realType.getSizes(), ComplexType::get(dtype)); + Value complexRes = + rewriter.create(loc, complexResType, stack); + + // Transpose back + if (needTranspose) { + if (failed(transposeValue(rewriter, loc, complexRes, dim, lastDim, + complexRes))) { + return failure(); + } + } + + rewriter.replaceOp(op, {complexRes}); + + return success(); + } +}; + +} // namespace + namespace { // Decompose `aten.hann_window` into `aten.arange.start`, `aten.mul.Scalar`, // `aten.sin` and `aten.square` or into `aten.ones` in the trivial case @@ -10011,6 +10201,7 @@ class DecomposeComplexOpsPass patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); 
addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index ebc43faa595c..9e54cc61007f 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -547,6 +547,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir deleted file mode 100644 index af0e4f4f6299..000000000000 --- a/test/Conversion/TorchToLinalg/spectral.mlir +++ /dev/null @@ -1,62 +0,0 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( -// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { -// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32> -// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> -// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32> -// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32> -// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex> -// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex): -// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex -// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex -// CHECK: } -> tensor<16x5xcomplex> -// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex> -> !torch.vtensor<[16,5],complex> -// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex> - -func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { - %int-1 = torch.constant.int -1 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> - return %out : !torch.vtensor<[16,5],complex> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( -// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> 
{ -// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> -// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> -// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32> -// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0] -// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32> -// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32> -// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex> -// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex): -// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex -// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex -// CHECK: } -> tensor<23x19xcomplex> -// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex> -// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex>) permutation = [1, 0] -// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : tensor<19x23xcomplex> -> !torch.vtensor<[19,23],complex> -// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex> -func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { - %int0 = torch.constant.int 0 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> - return %out : !torch.vtensor<[19,23],complex> -} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index f938a2637835..ac6ddc7585d8 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -171,3 +171,66 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16> return %0 : !torch.vtensor<[?],f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK: %[[INTM1:.*]] = torch.constant.int -1 +// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32> +// CHECK: %[[TRUE:.*]] = 
torch.constant.bool true +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32> +// CHECK: %[[INTM2:.*]] = torch.constant.int -2 +// CHECK: %[[VAR2:.*]] = torch.aten.unsqueeze %[[ARG0:.*]], %[[INTM2:.*]] : !torch.vtensor<[16,9],f32>, !torch.int -> !torch.vtensor<[16,1,9],f32> +// CHECK: %[[VAR3:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR1:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32> +// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." +// CHECK: %[[VAR4:.*]] = torch.aten.squeeze.dim %[[VAR3:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32> +// CHECK: %[[VAR5:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR0:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32> +// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." +// CHECK: %[[VAR6:.*]] = torch.aten.squeeze.dim %[[VAR5:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32> +// CHECK: %[[VAR7:.*]] = torch.aten.unsqueeze %[[VAR4:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32> +// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR6:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32> +// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[VAR7:.*]], %[[VAR8:.*]] : (!torch.vtensor<[16,5,1],f32>, !torch.vtensor<[16,5,1],f32>) -> !torch.list +// CHECK: %[[VAR10:.*]] = torch.aten.cat %[[VAR9:.*]], %[[INTM1:.*]] : !torch.list, !torch.int -> !torch.vtensor<[16,5,2],f32> +// CHECK: %[[VAR11:.*]] = torch.aten.view_as_complex %[[VAR10:.*]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[VAR11:.*]] : !torch.vtensor<[16,5],complex> +func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { + %int-1 = torch.constant.int -1 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> + return %out : !torch.vtensor<[16,5],complex> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK: %[[INTM1:.*]] = torch.constant.int -1 +// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32> +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32> +// CHECK: %[[INTM2:.*]] = torch.constant.int -2 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32> +// CHECK: %[[VAR3:.*]] = torch.aten.unsqueeze %[[VAR2:.*]], %[[INTM2:.*]] : !torch.vtensor<[23,36],f32>, !torch.int -> !torch.vtensor<[23,1,36],f32> +// CHECK: %[[VAR4:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR1:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32> +// CHECK: 
torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." +// CHECK: %[[VAR5:.*]] = torch.aten.squeeze.dim %[[VAR4:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32> +// CHECK: %[[VAR6:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR0:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32> +// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." +// CHECK: %[[VAR7:.*]] = torch.aten.squeeze.dim %[[VAR6:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32> +// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR5:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32> +// CHECK: %[[VAR9:.*]] = torch.aten.unsqueeze %[[VAR7:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32> +// CHECK: %[[VAR10:.*]] = torch.prim.ListConstruct %[[VAR8:.*]], %[[VAR9:.*]] : (!torch.vtensor<[23,19,1],f32>, !torch.vtensor<[23,19,1],f32>) -> !torch.list +// CHECK: %[[VAR11:.*]] = torch.aten.cat %[[VAR10:.*]], %[[INTM1:.*]] : !torch.list, !torch.int -> !torch.vtensor<[23,19,2],f32> +// CHECK: %[[VAR12:.*]] = torch.aten.view_as_complex %[[VAR11:.*]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex> +// CHECK: %[[VAR13:.*]] = torch.aten.transpose.int %[[VAR12:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[23,19],complex>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[VAR13:.*]] : !torch.vtensor<[19,23],complex> +func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> + return %out : !torch.vtensor<[19,23],complex> +} From 7f9fb25d8a56c1dffe3c9c7e06f58ace6a815f4d Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Tue, 12 Nov 2024 15:45:48 +0000 Subject: [PATCH 08/11] Add tests to FX_IMPORTER_STABLEHLO_XFAIL_SET too --- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f44031b377cb..0c859fd73ac0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -620,6 +620,8 @@ "AtenDiagEmbedOffsetDiag_basic", "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagSumExample_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", From 048dc5518ef77c1135ce41d052f9d7ad4292bafd Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:50:14 +0000 Subject: [PATCH 09/11] Add lowering to Linalg back --- lib/Conversion/TorchToLinalg/Linear.cpp | 162 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 - test/Conversion/TorchToLinalg/spectral.mlir | 62 +++++++ 3 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 test/Conversion/TorchToLinalg/spectral.mlir diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 557e6cd430ee..f6dd45df3b0e 100644 --- 
a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1374,6 +1374,166 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { }; } // namespace +namespace { + +/// From +/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +/// +/// Creates coefficients based on DFT definition, see +/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. +Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType, + bool isRealPart) { + // scale = 2 * pi / N + double scale = 2 * M_PI / matrixType.getDimSize(0); + + SmallVector values; + assert(matrixType.getRank() == 2 && "expected 2D matrix"); + for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { + for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { + double v = scale * i * j; + if (isRealPart) { + v = cos(v); + } else { + v = -sin(v); + } + values.push_back(b.getF32FloatAttr(v)); + } + } + return b.create( + loc, matrixType, DenseFPElementsAttr::get(matrixType, values)); +} + +/// From +/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +Value createLinalgMatmulOnTensors(OpBuilder b, Location loc, + RankedTensorType resultType, Value lhs, + Value rhs) { + Value zero = b.create( + loc, b.getZeroAttr(resultType.getElementType())); + Value emptyTensor = b.create( + loc, resultType.getShape(), resultType.getElementType(), + /*dyn_size=*/ValueRange{}); + Value zeroTensor = + b.create(loc, zero, emptyTensor).getResult(0); + + switch (llvm::cast(lhs.getType()).getRank()) { + case 1: + return b + .create(loc, TypeRange{resultType}, + ValueRange{lhs, rhs}, ValueRange{zeroTensor}) + .getResult(0); + case 2: + return b + .create(loc, TypeRange{resultType}, + ValueRange{lhs, rhs}, ValueRange{zeroTensor}) + .getResult(0); + default: + assert(false && "unhandled matmul type"); + return Value(); + } +} + +struct ConvertAtenFftRfftOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = adaptor.getSelf(); + + int64_t dim; + auto dimVal = op.getDim(); + if (isa(dimVal.getType())) { + dim = -1; + } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constant"); + } + + if (!isa(op.getN().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); + } + + if (!isa(op.getNorm().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); + } + + RankedTensorType inputType = + cast(adaptor.getSelf().getType()); + if (!inputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "unsupported: only ranked tensors are supported"); + } + if (!inputType.hasStaticShape() || inputType.getRank() > 2) { + return rewriter.notifyMatchFailure( + op, "unsupported: only static 1D or 2D FFT is supported"); + } + + const ArrayRef inputShape = inputType.getShape(); + dim += dim < 0 ? 
inputShape.size() : 0; + + const int64_t fftLength = inputShape[dim]; + const int64_t rank = inputType.getRank(); + const int64_t lastDim = rank - 1; + const int64_t outputFftDim = fftLength / 2 + 1; + const bool needTranspose = dim != lastDim; + + RankedTensorType newResultType = llvm::cast( + getTypeConverter()->convertType(op.getType())); + llvm::SmallVector componentShape(newResultType.getShape()); + + // Transpose if FFT dimension is not the last one + llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); + std::swap(perms[dim], perms[lastDim]); + if (needTranspose) { + self = transposeValue(loc, self, perms, rewriter); + for (size_t i = 0; i < componentShape.size(); i++) { + componentShape[i] = newResultType.getShape()[perms[i]]; + } + } + + RankedTensorType matrixType = RankedTensorType::get( + {fftLength, outputFftDim}, inputType.getElementType()); + + RankedTensorType componentsType = + RankedTensorType::get(componentShape, inputType.getElementType()); + + Value realMatrix = + getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true); + Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType, + self, realMatrix); + + Value imagMatrix = + getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false); + Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType, + self, imagMatrix); + + // Pack components into a complex tensor + Type elementType = newResultType.getElementType(); + auto toComplexBody = [&](OpBuilder &b, Location loc, + ValueRange payloadArgs) { + Value realElem = payloadArgs[0]; + Value imagElem = payloadArgs[1]; + Value complexElem = + b.create(loc, elementType, realElem, imagElem); + b.create(loc, complexElem); + }; + Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, {real, imag}, elementType, toComplexBody); + + // Transpose back + if (needTranspose) { + complexRes = transposeValue(loc, complexRes, perms, rewriter); + } + + rewriter.replaceOpWithNewOp(op, newResultType, complexRes); + return success(); + } +}; + +} // namespace + void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1388,4 +1548,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 9e54cc61007f..ebc43faa595c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -547,7 +547,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir new file mode 100644 index 000000000000..af0e4f4f6299 --- /dev/null +++ b/test/Conversion/TorchToLinalg/spectral.mlir @@ -0,0 +1,62 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( +// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// 
CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
+// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32>
+// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
+// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32>
+// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
+// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
+// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
+// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex<f32>>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex<f32>):
+// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex<f32>
+// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex<f32>
+// CHECK: } -> tensor<16x5xcomplex<f32>>
+// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex<f32>>
+
+func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+  %int-1 = torch.constant.int -1
+  %none = torch.constant.none
+  %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex<f32>>
+  return %out : !torch.vtensor<[16,5],complex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
+// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
+// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32>
+// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0]
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32>
+// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
+// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
+// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32>
+// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
+// CHECK: %[[IMAG_COMP:.*]] = 
linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> +// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex> +// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex): +// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex +// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex +// CHECK: } -> tensor<23x19xcomplex> +// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex> +// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex>) permutation = [1, 0] +// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : tensor<19x23xcomplex> -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex> +func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> + return %out : !torch.vtensor<[19,23],complex> +} From 740de4f2608370dc935ab9082b779ec48d1345ca Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Tue, 19 Nov 2024 13:46:25 +0000 Subject: [PATCH 10/11] Address review feedback --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 2 + lib/Conversion/TorchToLinalg/Linear.cpp | 6 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 119 ++++++------------ lib/Dialect/Torch/Utils/Utils.cpp | 12 ++ test/Dialect/Torch/decompose-complex-ops.mlir | 67 ++++------ 5 files changed, 79 insertions(+), 127 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index cf31c8f9735a..dd13c40facf8 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -20,6 +20,8 @@ namespace Torch { int64_t toPositiveDim(int64_t dim, int64_t inputRank); bool isValidDim(int64_t dim, int64_t inputRank); +Value toIntListConstruct(PatternRewriter &rewriter, Location loc, + ArrayRef cstInput, Torch::IntType intType); bool getListConstructElements(Value v, SmallVectorImpl &elems); /// Returns the index indicated by `v` for a list of given `length`. /// If the index is negative, it is adjusted to `length` + `v`. diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index f6dd45df3b0e..850e292fbbab 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1391,11 +1391,7 @@ Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType, for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { double v = scale * i * j; - if (isRealPart) { - v = cos(v); - } else { - v = -sin(v); - } + v = isRealPart ? 
cos(v) : -sin(v); values.push_back(b.getF32FloatAttr(v)); } } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 888b13f1c3d6..64d9c769f755 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9024,8 +9024,10 @@ namespace { /// Creates coefficients based on DFT definition, see /// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. +/// Even indices of the second dimension are for the real components of the +/// output. Odd indices for the imaginary components. Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, - ValueTensorType matrixType, bool isRealPart) { + ValueTensorType matrixType) { // scale = 2 * pi / N double scale = 2 * M_PI / matrixType.getSizes()[0]; @@ -9033,12 +9035,9 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, assert(matrixType.getSizes().size() == 2 && "expected 2D matrix"); for (auto i : llvm::seq(0, matrixType.getSizes()[0])) { for (auto j : llvm::seq(0, matrixType.getSizes()[1])) { - double v = scale * i * j; - if (isRealPart) { - v = cos(v); - } else { - v = -sin(v); - } + const bool isImagPart = j % 2; + double v = scale * i * (j / 2); + v = isImagPart ? -sin(v) : cos(v); values.push_back(rewriter.getF32FloatAttr(v)); } } @@ -9049,29 +9048,6 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, ArrayRef(values))); } -Value createBatchMatmul(PatternRewriter &rewriter, Location loc, Value lhs, - Value rhs) { - - BaseTensorType lhsType = cast(lhs.getType()); - assert(lhsType && lhsType.hasSizes()); - const ArrayRef lhsShape = lhsType.getSizes(); - assert(lhsShape.size() >= 2); - BaseTensorType rhsType = cast(rhs.getType()); - assert(rhsType && rhsType.hasSizes()); - const ArrayRef rhsShape = rhsType.getSizes(); - assert(rhsShape.size() >= 2); - assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]); - - SmallVector resShape(lhsShape); - resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1]; - - Type dtype = lhsType.getOptionalDtype(); - - ValueTensorType resType = - ValueTensorType::get(rewriter.getContext(), resShape, dtype); - return rewriter.create(loc, resType, lhs, rhs); -} - class DecomposeAtenFftRfftOp final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -9133,66 +9109,51 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern { return success(); }; + SmallVector lhsShape(inputShape); // Transpose if FFT dimension is not the last one if (needTranspose) { if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) { return failure(); } + std::swap(lhsShape[dim], lhsShape[lastDim]); } + // self : (D_0 x ... x D_m x fftLength) - // lhs = unsqueeze(self, -2) : (D x 1 x fftLength), D = [D_1, D_2, ...] 
- Value unsqueezeDim = - rewriter.create(loc, rewriter.getI64IntegerAttr(-2)); - auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim); - if (failed(unsqueezed)) - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueezed tensor"); - Value lhs = *unsqueezed; Type dtype = inputType.getOptionalDtype(); - Value real, complex; - - for (const bool isRealPart : {true, false}) { - - // coeff : (fftLength x outputFftDim) - ValueTensorType matrixType = ValueTensorType::get( - op.getContext(), SmallVector{fftLength, outputFftDim}, - dtype); - Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType, - /*isRealPart=*/isRealPart); - - // X = matmul(lhs, coeff) : (D x 1 x outputFftDim) - Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix); - - // Y = squeeze(X, -2) : (D x outputFftDim) - auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes); - if (failed(squeezed)) - return rewriter.notifyMatchFailure(op, - "cannot generate squeezed tensor"); - - if (isRealPart) { - real = *squeezed; - } else { - complex = *squeezed; - } - } - - // Pack components into a complex tensor - BaseTensorType realType = cast(real.getType()); - SmallVector stackSizes(realType.getSizes()); - stackSizes.push_back(2); - Value sequence = rewriter.create( - loc, ListType::get(op.getContext(), realType), - ValueRange{real, complex}); - Type stackType = realType.getWithSizesAndDtype(stackSizes, dtype); + // coeff : (fftLength x outputFftDim*2) + ValueTensorType matrixType = ValueTensorType::get( + op.getContext(), SmallVector{fftLength, outputFftDim * 2}, + dtype); + Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType); + + // X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2) + SmallVector matmulShape(lhsShape.begin(), lhsShape.end() - 1); + matmulShape.push_back(outputFftDim * 2); + ValueTensorType matmulType = + ValueTensorType::get(op.getContext(), matmulShape, dtype); + Value flatRes = + rewriter.create(loc, matmulType, self, coeffMatrix); + + // Y = unflatten(X, -1, [outputFftDim, 2]) + // : (D_0 x ... x D_m x outputFftDim x 2) + // Z = view_as_complex(Y) : complex(D_0 x ... 
x D_m x outputFftDim) + SmallVector complexResShape(matmulShape); + complexResShape.back() = outputFftDim; + SmallVector unflattenedResShape(complexResShape); + unflattenedResShape.push_back(2); + Type unflattenedResType = + ValueTensorType::get(op.getContext(), unflattenedResShape, dtype); Value cstMinusOne = rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); - Value stack = - rewriter.create(loc, stackType, sequence, cstMinusOne); - Type complexResType = ValueTensorType::get( - op.getContext(), realType.getSizes(), ComplexType::get(dtype)); - Value complexRes = - rewriter.create(loc, complexResType, stack); + Value unflattenSizes = toIntListConstruct( + rewriter, loc, {outputFftDim, 2}, IntType::get(rewriter.getContext())); + Value unflattenedRes = rewriter.create( + loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes); + Type complexResType = ValueTensorType::get(op.getContext(), complexResShape, + ComplexType::get(dtype)); + Value complexRes = rewriter.create(loc, complexResType, + unflattenedRes); // Transpose back if (needTranspose) { diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 664bbb2d5d8e..77840d206e02 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -36,6 +36,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { return dim; } +Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc, + ArrayRef cstInput, + Torch::IntType intType) { + SmallVector cstValues; + for (int64_t i : cstInput) { + cstValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + loc, Torch::ListType::get(intType), cstValues); +} + bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index ac6ddc7585d8..bf37a484720f 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -175,26 +175,16 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v // ----- // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { -// CHECK: %[[INTM1:.*]] = torch.constant.int -1 -// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32> -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32> -// CHECK: %[[INTM2:.*]] = torch.constant.int -2 -// CHECK: %[[VAR2:.*]] = torch.aten.unsqueeze %[[ARG0:.*]], %[[INTM2:.*]] : !torch.vtensor<[16,9],f32>, !torch.int -> !torch.vtensor<[16,1,9],f32> -// CHECK: %[[VAR3:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR1:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32> -// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." 
-// CHECK: %[[VAR4:.*]] = torch.aten.squeeze.dim %[[VAR3:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32> -// CHECK: %[[VAR5:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR0:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32> -// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." -// CHECK: %[[VAR6:.*]] = torch.aten.squeeze.dim %[[VAR5:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32> -// CHECK: %[[VAR7:.*]] = torch.aten.unsqueeze %[[VAR4:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32> -// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR6:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32> -// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[VAR7:.*]], %[[VAR8:.*]] : (!torch.vtensor<[16,5,1],f32>, !torch.vtensor<[16,5,1],f32>) -> !torch.list -// CHECK: %[[VAR10:.*]] = torch.aten.cat %[[VAR9:.*]], %[[INTM1:.*]] : !torch.list, !torch.int -> !torch.vtensor<[16,5,2],f32> -// CHECK: %[[VAR11:.*]] = torch.aten.view_as_complex %[[VAR10:.*]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex> -// CHECK: return %[[VAR11:.*]] : !torch.vtensor<[16,5],complex> +// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5 +// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16 +// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32> +// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32> +// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list -> !torch.vtensor<[16,5,2],f32> +// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { %int-1 = torch.constant.int -1 %none = torch.constant.none @@ -205,29 +195,20 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> // ----- // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { -// CHECK: %[[INTM1:.*]] = torch.constant.int -1 -// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32> -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32> -// CHECK: %[[INTM2:.*]] = torch.constant.int -2 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32> -// CHECK: %[[VAR3:.*]] = torch.aten.unsqueeze %[[VAR2:.*]], %[[INTM2:.*]] : !torch.vtensor<[23,36],f32>, !torch.int -> !torch.vtensor<[23,1,36],f32> 
-// CHECK: %[[VAR4:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR1:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32> -// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." -// CHECK: %[[VAR5:.*]] = torch.aten.squeeze.dim %[[VAR4:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32> -// CHECK: %[[VAR6:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR0:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32> -// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1." -// CHECK: %[[VAR7:.*]] = torch.aten.squeeze.dim %[[VAR6:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32> -// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR5:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32> -// CHECK: %[[VAR9:.*]] = torch.aten.unsqueeze %[[VAR7:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32> -// CHECK: %[[VAR10:.*]] = torch.prim.ListConstruct %[[VAR8:.*]], %[[VAR9:.*]] : (!torch.vtensor<[23,19,1],f32>, !torch.vtensor<[23,19,1],f32>) -> !torch.list -// CHECK: %[[VAR11:.*]] = torch.aten.cat %[[VAR10:.*]], %[[INTM1:.*]] : !torch.list, !torch.int -> !torch.vtensor<[23,19,2],f32> -// CHECK: %[[VAR12:.*]] = torch.aten.view_as_complex %[[VAR11:.*]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex> -// CHECK: %[[VAR13:.*]] = torch.aten.transpose.int %[[VAR12:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[23,19],complex>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex> -// CHECK: return %[[VAR13:.*]] : !torch.vtensor<[19,23],complex> +// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19 +// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23 +// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32> +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32> +// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32> +// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list -> !torch.vtensor<[23,19,2],f32> +// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex> +// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex> func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { %int0 = torch.constant.int 0 %none = torch.constant.none From 535d9c17121a67443f9cbeefe0f7bc4b227e2fcf Mon Sep 17 00:00:00 2001 From: giacs-epic 
<179146510+giacs-epic@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:40:34 +0000 Subject: [PATCH 11/11] Refactor conversion to Linalg --- lib/Conversion/TorchToLinalg/Linear.cpp | 192 ++++++++++++-------- test/Conversion/TorchToLinalg/spectral.mlir | 82 +++++---- 2 files changed, 154 insertions(+), 120 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 850e292fbbab..6dcff775ef39 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1376,57 +1376,38 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { namespace { -/// From -/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp -/// /// Creates coefficients based on DFT definition, see /// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. -Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType, - bool isRealPart) { +Value getDFTMatmulCoeff(OpBuilder b, Location loc, + RankedTensorType matrixType) { + + ComplexType complexTy = llvm::cast(matrixType.getElementType()); + mlir::FloatType floatType = + llvm::cast(complexTy.getElementType()); + // scale = 2 * pi / N double scale = 2 * M_PI / matrixType.getDimSize(0); - SmallVector values; - assert(matrixType.getRank() == 2 && "expected 2D matrix"); + SmallVector> values; for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { double v = scale * i * j; - v = isRealPart ? cos(v) : -sin(v); - values.push_back(b.getF32FloatAttr(v)); + double realV = cos(v); + double imagV = -sin(v); + + bool unused; + APFloat real(realV); + real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + APFloat imag(imagV); + imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + + values.push_back(std::complex(real, imag)); } } return b.create( - loc, matrixType, DenseFPElementsAttr::get(matrixType, values)); -} - -/// From -/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp -Value createLinalgMatmulOnTensors(OpBuilder b, Location loc, - RankedTensorType resultType, Value lhs, - Value rhs) { - Value zero = b.create( - loc, b.getZeroAttr(resultType.getElementType())); - Value emptyTensor = b.create( - loc, resultType.getShape(), resultType.getElementType(), - /*dyn_size=*/ValueRange{}); - Value zeroTensor = - b.create(loc, zero, emptyTensor).getResult(0); - - switch (llvm::cast(lhs.getType()).getRank()) { - case 1: - return b - .create(loc, TypeRange{resultType}, - ValueRange{lhs, rhs}, ValueRange{zeroTensor}) - .getResult(0); - case 2: - return b - .create(loc, TypeRange{resultType}, - ValueRange{lhs, rhs}, ValueRange{zeroTensor}) - .getResult(0); - default: - assert(false && "unhandled matmul type"); - return Value(); - } + loc, matrixType, DenseElementsAttr::get(matrixType, values)); } struct ConvertAtenFftRfftOp final : OpConversionPattern { @@ -1461,69 +1442,120 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern { return rewriter.notifyMatchFailure( op, "unsupported: only ranked tensors are supported"); } - if (!inputType.hasStaticShape() || inputType.getRank() > 2) { - return rewriter.notifyMatchFailure( - op, "unsupported: only static 1D or 2D FFT is supported"); - } const ArrayRef inputShape = inputType.getShape(); dim += dim < 0 ? 
inputShape.size() : 0; const int64_t fftLength = inputShape[dim]; + if (fftLength == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure( + op, "unsupported: FFT signal length must be static"); + } const int64_t rank = inputType.getRank(); const int64_t lastDim = rank - 1; const int64_t outputFftDim = fftLength / 2 + 1; const bool needTranspose = dim != lastDim; - RankedTensorType newResultType = llvm::cast( - getTypeConverter()->convertType(op.getType())); - llvm::SmallVector componentShape(newResultType.getShape()); - // Transpose if FFT dimension is not the last one llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); std::swap(perms[dim], perms[lastDim]); if (needTranspose) { self = transposeValue(loc, self, perms, rewriter); - for (size_t i = 0; i < componentShape.size(); i++) { - componentShape[i] = newResultType.getShape()[perms[i]]; - } } - RankedTensorType matrixType = RankedTensorType::get( - {fftLength, outputFftDim}, inputType.getElementType()); - - RankedTensorType componentsType = - RankedTensorType::get(componentShape, inputType.getElementType()); - - Value realMatrix = - getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true); - Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType, - self, realMatrix); - - Value imagMatrix = - getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false); - Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType, - self, imagMatrix); - - // Pack components into a complex tensor - Type elementType = newResultType.getElementType(); - auto toComplexBody = [&](OpBuilder &b, Location loc, - ValueRange payloadArgs) { - Value realElem = payloadArgs[0]; - Value imagElem = payloadArgs[1]; - Value complexElem = - b.create(loc, elementType, realElem, imagElem); - b.create(loc, complexElem); - }; - Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric( - rewriter, loc, {real, imag}, elementType, toComplexBody); + RankedTensorType newResultType = llvm::cast( + getTypeConverter()->convertType(op.getType())); + ComplexType complexElemType = + llvm::cast(newResultType.getElementType()); + Type elemType = complexElemType.getElementType(); + + // coeffMatrix : tensor> + RankedTensorType coeffType = + RankedTensorType::get({fftLength, outputFftDim}, complexElemType); + // coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N) + Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType); + + // #matmul_trait = { + // indexing_maps = [ + // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>, + // affine_map<(d_0, ... d_m, f, o) -> (f, o)>, + // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... 
d_m, o)> + // ], + // iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"] + // } + // linalg.generic #matmul_trait + // ins(%A, %B : tensor, + // tensor>) + // outs(%C : tensor>) { + // ^bb0(%a: f32, %b: complex, %c: complex) : + // %re = complex.re %b : f32 + // %im = complex.im %b : f32 + // %mulre = arith.mulf %a, %re: f32 + // %mulim = arith.mulf %a, %im: f32 + // %mulcplx = complex.create %mulre, %mulim : complex + // %add = complex.add %c, %mulcplx: complex + // linalg.yield %add : complex + // } -> (tensor>) + + Value lhs = self; + Value rhs = coeffMatrix; + RankedTensorType lhsType = llvm::cast(lhs.getType()); + ArrayRef lhsShape(lhsType.getShape()); + ArrayRef rhsShape(coeffType.getShape()); + + unsigned batchRank = lhsShape.size() - 1; + + SmallVector lhsExpr; + SmallVector rhsExpr; + SmallVector outExpr; + SmallVector iteratorTypes( + batchRank, utils::IteratorType::parallel); + SmallVector resultShape; + for (unsigned i = 0; i < batchRank; i++) { + lhsExpr.push_back(rewriter.getAffineDimExpr(i)); + outExpr.push_back(rewriter.getAffineDimExpr(i)); + resultShape.push_back(getDimOp(rewriter, loc, lhs, i)); + } + unsigned fIdx = batchRank, oIdx = batchRank + 1; + lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)}); + rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx), + rewriter.getAffineDimExpr(oIdx)}); + outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)}); + resultShape.insert(resultShape.end(), + {getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)}); + + Value zeroTensor = + createZeroInitTensor(rewriter, loc, resultShape, complexElemType); + auto indexingMaps = AffineMap::inferFromExprList( + {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); + iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction, + utils::IteratorType::parallel}); + + Value complexRes = + rewriter + .create( + loc, zeroTensor.getType(), + /*inputs=*/ValueRange{lhs, rhs}, + /*outputs=*/zeroTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], res = args[2]; + Value re = b.create(loc, elemType, r); + Value im = b.create(loc, elemType, r); + Value mulRe = b.create(loc, l, re); + Value mulIm = b.create(loc, l, im); + Value mulCplx = b.create( + loc, complexElemType, mulRe, mulIm); + Value add = b.create(loc, mulCplx, res); + b.create(loc, add); + }) + .getResult(0); // Transpose back if (needTranspose) { complexRes = transposeValue(loc, complexRes, perms, rewriter); } - rewriter.replaceOpWithNewOp(op, newResultType, complexRes); + rewriter.replaceOp(op, complexRes); return success(); } }; diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir index af0e4f4f6299..abd45183bd84 100644 --- a/test/Conversion/TorchToLinalg/spectral.mlir +++ b/test/Conversion/TorchToLinalg/spectral.mlir @@ -1,25 +1,28 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)> + // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( -// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { -// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[REAL_COEFF:.*]] = 
arith.constant dense<{{.*}}> : tensor<9x5xf32> -// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> -// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32> -// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32> -// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32> -// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex> -// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex): -// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex -// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex +// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex> +// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> +// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex> +// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex>) -> tensor<16x5xcomplex> +// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex>) outs(%[[VAR2]] : tensor<16x5xcomplex>) { +// CHECK: ^bb0(%in: f32, %in_1: complex, %out: complex): +// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex +// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex +// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32 +// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32 +// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex +// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex +// CHECK: linalg.yield %[[VAR10]] : complex // CHECK: } -> tensor<16x5xcomplex> -// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex> -> !torch.vtensor<[16,5],complex> -// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex> +// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { %int-1 = torch.constant.int -1 @@ -31,29 +34,28 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> // ----- // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( -// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { -// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> -// 
CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32> -// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> -// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32> -// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0] -// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32> -// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32> -// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32> -// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex> -// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex): -// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex -// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex +// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex> +// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> +// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32> +// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0] +// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex> +// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex>) -> tensor<23x19xcomplex> +// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex>) outs(%[[VAR3]] : tensor<23x19xcomplex>) { +// CHECK: ^bb0(%in: f32, %in_2: complex, %out: complex): +// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex +// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex +// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32 +// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32 +// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex +// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex +// CHECK: linalg.yield %[[VAR12]] : complex // CHECK: } -> tensor<23x19xcomplex> -// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex> -// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex>) permutation = [1, 0] -// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : tensor<19x23xcomplex> -> 
!torch.vtensor<[19,23],complex<f32>>
-// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex<f32>>
+// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
+// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex<f32>>) outs(%[[VAR5]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
+// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
   %int0 = torch.constant.int 0
   %none = torch.constant.none