diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 44bf8ab2e0d4..2d03438967db 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -12980,6 +12980,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
   }];
 }
 
+def Torch_AtenFftRfftOp : Torch_Op<"aten.fft_rfft", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalIntType:$n,
+    Torch_IntType:$dim,
+    AnyTorchOptionalStringType:$norm
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFftRfftOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenFftRfftOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index cf31c8f9735a..dd13c40facf8 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -20,6 +20,8 @@ namespace Torch {
 
 int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
+Value toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                         ArrayRef<int64_t> cstInput, Torch::IntType intType);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 /// Returns the index indicated by `v` for a list of given `length`.
 /// If the index is negative, it is adjusted to `length` + `v`.
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index a4962d12abdc..6dcff775ef39 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -11,6 +11,7 @@
 
 #include "PopulatePatterns.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/Matchers.h"
@@ -1373,6 +1374,194 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 };
 } // namespace
 
+namespace {
+
+/// Creates coefficients based on DFT definition, see
+/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
+Value getDFTMatmulCoeff(OpBuilder b, Location loc,
+                        RankedTensorType matrixType) {
+
+  ComplexType complexTy = llvm::cast<ComplexType>(matrixType.getElementType());
+  mlir::FloatType floatType =
+      llvm::cast<mlir::FloatType>(complexTy.getElementType());
+
+  // scale = 2 * pi / N
+  double scale = 2 * M_PI / matrixType.getDimSize(0);
+
+  SmallVector<std::complex<APFloat>> values;
+  for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
+    for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
+      double v = scale * i * j;
+      double realV = cos(v);
+      double imagV = -sin(v);
+
+      bool unused;
+      APFloat real(realV);
+      real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                   &unused);
+      APFloat imag(imagV);
+      imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                   &unused);
+
+      values.push_back(std::complex<APFloat>(real, imag));
+    }
+  }
+  return b.create<arith::ConstantOp>(
+      loc, matrixType, DenseElementsAttr::get(matrixType, values));
+}
+
+struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = adaptor.getSelf();
+
+    int64_t dim;
+    auto dimVal = op.getDim();
+    if (isa<Torch::NoneType>(dimVal.getType())) {
+      dim = -1;
+    } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: requires dim to be constant");
+    }
+
+    if (!isa<Torch::NoneType>(op.getN().getType())) {
+      return rewriter.notifyMatchFailure(op, "unimplemented: parameter n");
+    }
+
+    if (!isa<Torch::NoneType>(op.getNorm().getType())) {
+      return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm");
+    }
+
+    RankedTensorType inputType =
+        cast<RankedTensorType>(adaptor.getSelf().getType());
+    if (!inputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: only ranked tensors are supported");
+    }
+
+    const ArrayRef<int64_t> inputShape = inputType.getShape();
+    dim += dim < 0 ? inputShape.size() : 0;
+
+    const int64_t fftLength = inputShape[dim];
+    if (fftLength == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: FFT signal length must be static");
+    }
+    const int64_t rank = inputType.getRank();
+    const int64_t lastDim = rank - 1;
+    const int64_t outputFftDim = fftLength / 2 + 1;
+    const bool needTranspose = dim != lastDim;
+
+    // Transpose if FFT dimension is not the last one
+    llvm::SmallVector<int64_t> perms = llvm::to_vector(llvm::seq(rank));
+    std::swap(perms[dim], perms[lastDim]);
+    if (needTranspose) {
+      self = transposeValue(loc, self, perms, rewriter);
+    }
+
+    RankedTensorType newResultType = llvm::cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getType()));
+    ComplexType complexElemType =
+        llvm::cast<ComplexType>(newResultType.getElementType());
+    Type elemType = complexElemType.getElementType();
+
+    // coeffMatrix : tensor<fftLength x outputFftDim x complex<f32>>
+    RankedTensorType coeffType =
+        RankedTensorType::get({fftLength, outputFftDim}, complexElemType);
+    // coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N)
+    Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType);
+
+    // #matmul_trait = {
+    //   indexing_maps = [
+    //     affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>,
+    //     affine_map<(d_0, ... d_m, f, o) -> (f, o)>,
+    //     affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, o)>
+    //   ],
+    //   iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"]
+    // }
+    // linalg.generic #matmul_trait
+    //   ins(%A, %B : tensor<D_0 x ... x D_m x f x f32>,
+    //                tensor<f x o x complex<f32>>)
+    //   outs(%C : tensor<D_0 x ... x D_m x o x complex<f32>>) {
+    //   ^bb0(%a: f32, %b: complex<f32>, %c: complex<f32>) :
+    //     %re = complex.re %b : f32
+    //     %im = complex.im %b : f32
+    //     %mulre = arith.mulf %a, %re: f32
+    //     %mulim = arith.mulf %a, %im: f32
+    //     %mulcplx = complex.create %mulre, %mulim : complex<f32>
+    //     %add = complex.add %c, %mulcplx: complex<f32>
+    //     linalg.yield %add : complex<f32>
+    // } -> (tensor<D_0 x ... x D_m x o x complex<f32>>)
+
+    Value lhs = self;
+    Value rhs = coeffMatrix;
+    RankedTensorType lhsType = llvm::cast<RankedTensorType>(lhs.getType());
+    ArrayRef<int64_t> lhsShape(lhsType.getShape());
+    ArrayRef<int64_t> rhsShape(coeffType.getShape());
+
+    unsigned batchRank = lhsShape.size() - 1;
+
+    SmallVector<AffineExpr> lhsExpr;
+    SmallVector<AffineExpr> rhsExpr;
+    SmallVector<AffineExpr> outExpr;
+    SmallVector<utils::IteratorType> iteratorTypes(
+        batchRank, utils::IteratorType::parallel);
+    SmallVector<Value> resultShape;
+    for (unsigned i = 0; i < batchRank; i++) {
+      lhsExpr.push_back(rewriter.getAffineDimExpr(i));
+      outExpr.push_back(rewriter.getAffineDimExpr(i));
+      resultShape.push_back(getDimOp(rewriter, loc, lhs, i));
+    }
+    unsigned fIdx = batchRank, oIdx = batchRank + 1;
+    lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)});
+    rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx),
+                                   rewriter.getAffineDimExpr(oIdx)});
+    outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)});
+    resultShape.insert(resultShape.end(),
+                       {getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)});
+
+    Value zeroTensor =
+        createZeroInitTensor(rewriter, loc, resultShape, complexElemType);
+    auto indexingMaps = AffineMap::inferFromExprList(
+        {lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
+    iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction,
+                                               utils::IteratorType::parallel});
+
+    Value complexRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhs, rhs},
+                /*outputs=*/zeroTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value l = args[0], r = args[1], res = args[2];
+                  Value re = b.create<complex::ReOp>(loc, elemType, r);
+                  Value im = b.create<complex::ImOp>(loc, elemType, r);
+                  Value mulRe = b.create<arith::MulFOp>(loc, l, re);
+                  Value mulIm = b.create<arith::MulFOp>(loc, l, im);
+                  Value mulCplx = b.create<complex::CreateOp>(
+                      loc, complexElemType, mulRe, mulIm);
+                  Value add = b.create<complex::AddOp>(loc, mulCplx, res);
+                  b.create<linalg::YieldOp>(loc, add);
+                })
+            .getResult(0);
+
+    // Transpose back
+    if (needTranspose) {
+      complexRes = transposeValue(loc, complexRes, perms, rewriter);
+    }
+
+    rewriter.replaceOp(op, complexRes);
+    return success();
+  }
+};
+
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1387,4 +1576,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
+  target.addIllegalOp<AtenFftRfftOp>();
+  patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 559726f20659..25761f04b882 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10689,6 +10689,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.fft_rfft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Expected dim in [-rank, rank-1]\"\n"
+"    %false = torch.constant.bool false\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %11 = torch.aten.add.int %arg2, %10 : !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.If.yield %11 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %arg2 : !torch.int\n"
+"    }\n"
+"    %2 = torch.aten.ge.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %11 = torch.aten.lt.int %1, %10 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %11 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    torch.prim.Loop %5, %true, init() {\n"
+"    ^bb0(%arg4: !torch.int):\n"
+"      %10 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %11 = torch.aten.append.t %4, %10 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.add.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %9 = torch.aten._set_item.t %4, %1, %8 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
 "    %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
@@ -12766,6 +12810,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.fft_rfft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int7 = torch.constant.int 7\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int8 = torch.constant.int 8\n"
+"    %int5 = torch.constant.int 5\n"
+"    %0 = torch.prim.Uninitialized : !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int8 : !torch.int\n"
+"    } else {\n"
+"      %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int9 : !torch.int\n"
+"      } else {\n"
+"        %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
+"        %7 = torch.prim.If %6 -> (!torch.int) {\n"
+"          torch.prim.If.yield %int10 : !torch.int\n"
+"        } else {\n"
+"          %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"          %9 = torch.prim.If %8 -> (!torch.int) {\n"
+"            torch.prim.If.yield %int9 : !torch.int\n"
+"          } else {\n"
+"            torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"            torch.prim.If.yield %0 : !torch.int\n"
+"          }\n"
+"          torch.prim.If.yield %9 : !torch.int\n"
+"        }\n"
+"        torch.prim.If.yield %7 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
+"    }\n"
+"    return %3 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.int {\n"
 "    %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
 "    %int7 = torch.constant.int 7\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 29c176f96afd..64d9c769f755 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9020,6 +9020,157 @@ class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
 };
 } // namespace
 
+namespace {
+
+/// Creates coefficients based on DFT definition, see
+/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
+/// Even indices of the second dimension are for the real components of the
+/// output. Odd indices for the imaginary components.
+Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
+                        ValueTensorType matrixType) {
+  // scale = 2 * pi / N
+  double scale = 2 * M_PI / matrixType.getSizes()[0];
+
+  SmallVector<Attribute> values;
+  assert(matrixType.getSizes().size() == 2 && "expected 2D matrix");
+  for (auto i : llvm::seq<unsigned>(0, matrixType.getSizes()[0])) {
+    for (auto j : llvm::seq<unsigned>(0, matrixType.getSizes()[1])) {
+      const bool isImagPart = j % 2;
+      double v = scale * i * (j / 2);
+      v = isImagPart ? -sin(v) : cos(v);
+      values.push_back(rewriter.getF32FloatAttr(v));
+    }
+  }
+
+  return rewriter.create<ValueTensorLiteralOp>(
+      loc, matrixType,
+      DenseElementsAttr::get(matrixType.toBuiltinTensor(),
+                             ArrayRef(values)));
+}
+
+class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFftRfftOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+
+    int64_t dim;
+    auto dimVal = op.getDim();
+    if (isa<Torch::NoneType>(dimVal.getType())) {
+      dim = -1;
+    } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: requires dim to be constant");
+    }
+
+    if (!isa<Torch::NoneType>(op.getN().getType())) {
+      return rewriter.notifyMatchFailure(op, "unimplemented: parameter n");
+    }
+
+    if (!isa<Torch::NoneType>(op.getNorm().getType())) {
+      return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm");
+    }
+
+    BaseTensorType inputType = cast<BaseTensorType>(self.getType());
+
+    if (!inputType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: only ranked tensors are supported");
+    }
+
+    const ArrayRef<int64_t> inputShape = inputType.getSizes();
+    dim += dim < 0 ? inputShape.size() : 0;
+
+    const int64_t fftLength = inputShape[dim];
+    if (fftLength == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: input signal length must be known");
+    }
+    const int64_t rank = inputShape.size();
+    const int64_t lastDim = rank - 1;
+    const int64_t outputFftDim = fftLength / 2 + 1;
+    const bool needTranspose = dim != lastDim;
+
+    auto transposeValue = [](PatternRewriter &rewriter, Location loc,
+                             Value input, int64_t dimA, int64_t dimB,
+                             Value &transposed) {
+      Type transposedType;
+      if (failed(getTransposedType(cast<BaseTensorType>(input.getType()), dimA,
+                                   dimB, transposedType)))
+        return failure();
+      Value cstDimA =
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dimA));
+      Value cstDimB =
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dimB));
+      transposed = rewriter.create<AtenTransposeIntOp>(loc, transposedType,
+                                                       input, cstDimA, cstDimB);
+      return success();
+    };
+
+    SmallVector<int64_t> lhsShape(inputShape);
+    // Transpose if FFT dimension is not the last one
+    if (needTranspose) {
+      if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) {
+        return failure();
+      }
+      std::swap(lhsShape[dim], lhsShape[lastDim]);
+    }
+    // self : (D_0 x ... x D_m x fftLength)
+
+    Type dtype = inputType.getOptionalDtype();
+
+    // coeff : (fftLength x outputFftDim*2)
+    ValueTensorType matrixType = ValueTensorType::get(
+        op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim * 2},
+        dtype);
+    Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType);
+
+    // X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2)
+    SmallVector<int64_t> matmulShape(lhsShape.begin(), lhsShape.end() - 1);
+    matmulShape.push_back(outputFftDim * 2);
+    ValueTensorType matmulType =
+        ValueTensorType::get(op.getContext(), matmulShape, dtype);
+    Value flatRes =
+        rewriter.create<AtenMatmulOp>(loc, matmulType, self, coeffMatrix);
+
+    // Y = unflatten(X, -1, [outputFftDim, 2])
+    //                            : (D_0 x ... x D_m x outputFftDim x 2)
+    // Z = view_as_complex(Y) : complex(D_0 x ... x D_m x outputFftDim)
+    SmallVector<int64_t> complexResShape(matmulShape);
+    complexResShape.back() = outputFftDim;
+    SmallVector<int64_t> unflattenedResShape(complexResShape);
+    unflattenedResShape.push_back(2);
+    Type unflattenedResType =
+        ValueTensorType::get(op.getContext(), unflattenedResShape, dtype);
+    Value cstMinusOne =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value unflattenSizes = toIntListConstruct(
+        rewriter, loc, {outputFftDim, 2}, IntType::get(rewriter.getContext()));
+    Value unflattenedRes = rewriter.create<AtenUnflattenIntOp>(
+        loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes);
+    Type complexResType = ValueTensorType::get(op.getContext(), complexResShape,
+                                               ComplexType::get(dtype));
+    Value complexRes = rewriter.create<AtenViewAsComplexOp>(loc, complexResType,
+                                                            unflattenedRes);
+
+    // Transpose back
+    if (needTranspose) {
+      if (failed(transposeValue(rewriter, loc, complexRes, dim, lastDim,
+                                complexRes))) {
+        return failure();
+      }
+    }
+
+    rewriter.replaceOp(op, {complexRes});
+
+    return success();
+  }
+};
+
+} // namespace
+
 namespace {
 // Decompose `aten.hann_window` into `aten.arange.start`, `aten.mul.Scalar`,
 // `aten.sin` and `aten.square` or into `aten.ones` in the trivial case
@@ -10011,6 +10162,7 @@ class DecomposeComplexOpsPass
         patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFftRfftOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 664bbb2d5d8e..77840d206e02 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -36,6 +36,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
   return dim;
 }
 
+Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                                ArrayRef<int64_t> cstInput,
+                                Torch::IntType intType) {
+  SmallVector<Value> cstValues;
+  for (int64_t i : cstInput) {
+    cstValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(i)));
+  }
+  return rewriter.create<Torch::PrimListConstructOp>(
+      loc, Torch::ListType::get(intType), cstValues);
+}
+
 bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
   auto listConstruct = v.getDefiningOp<Torch::PrimListConstructOp>();
   if (!listConstruct)
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 052eceb5ac4a..0c859fd73ac0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -620,6 +620,8 @@
     "AtenDiagEmbedOffsetDiag_basic",
     "AtenDiagEmbedRevDimDiag_basic",
     "AtenEmbeddingBagSumExample_basic",
+    "AtenFftRfft2DLastDim_basic",
+    "AtenFftRfft2DMiddleDim_basic",
     "AtenFloatScalarModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",
@@ -2675,6 +2677,8 @@
     "AtenDiagEmbedRevDimDiag_basic",
     "AtenEmbeddingBagStaticModule_basic",
    "AtenEmbeddingBagSumExample_basic",
+    "AtenFftRfft2DLastDim_basic",
+    "AtenFftRfft2DMiddleDim_basic",
     "AtenFloatScalarModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 2b7db059bb42..f883ade14fae 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2192,6 +2192,18 @@ def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]:
 def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
     return self
 
+@check_shape_function([
+    Invocation(TensorOfShape(3, 9, 5), None, -2, None)  # Second-last dim
+])
+def aten〇fft_rfft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
+    dim = (dim + len(self)) if dim < 0 else dim
+    assert dim >= 0 and dim < len(self), "Expected dim in [-rank, rank-1]"
+    out: List[int] = []
+    for s in self:
+        out.append(s)
+    out[dim] = self[dim] // 2 + 1
+    return out
+
 @check_shape_function([
     Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True)  # With an explicit 1-D window.
 ])
@@ -3732,6 +3744,23 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] =
     else:
         assert False, "Unsupported dtype"
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex32, torch.complex64, torch.complex128, torch.bfloat16}))
+def aten〇fft_rfft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype == torch.float16:
+        return torch.complex32
+    elif self_dtype == torch.float32:
+        return torch.complex64
+    elif self_dtype == torch.float64:
+        return torch.complex128
+    elif is_integer_dtype(self_dtype):
+        return torch.complex64
+    else:
+        assert False, "Unsupported dtype"
+
+
+
 @check_dtype_function([
     Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32
     Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index ea5070a8c0bb..97f78199ce90 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -951,6 +951,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)"
     )
     emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
+    emit("aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)")
    emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)")
     emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py
index 8e259fbe0c2a..57a7270f9d09 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py
@@ -51,3 +51,43 @@ def forward(self):
 @register_test_case(module_factory=lambda: AtenHannWindowPeriodicTrueModule())
 def AtenHannWindowPeriodicTrueModule_basic(module, tu: TestUtils):
     module.forward()
+
+
+# ==============================================================================
+
+
+class AtenFftRfft2DLastDim(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([16, 9], torch.float32, True),
+        ]
+    )
+    def forward(self, input):
+        return torch.fft.rfft(input, dim=-1)
+
+
+@register_test_case(module_factory=lambda: AtenFftRfft2DLastDim())
+def AtenFftRfft2DLastDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(16, 9))
+
+
+# ==============================================================================
+
+
+class AtenFftRfft2DMiddleDim(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([36, 10], torch.float32, True),
+        ]
+    )
+    def forward(self, input):
+        return torch.fft.rfft(input, dim=0)
+
+
+@register_test_case(module_factory=lambda: AtenFftRfft2DMiddleDim())
+def AtenFftRfft2DMiddleDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(36, 10))
diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir
new file mode 100644
index 000000000000..abd45183bd84
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/spectral.mlir
@@ -0,0 +1,64 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
+// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex<f32>>
+// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
+// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
+// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>>
+// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex<f32>>) outs(%[[VAR2]] : tensor<16x5xcomplex<f32>>) {
+// CHECK: ^bb0(%in: f32, %in_1: complex<f32>, %out: complex<f32>):
+// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex<f32>
+// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex<f32>
+// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32
+// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32
+// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex<f32>
+// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex<f32>
+// CHECK: linalg.yield %[[VAR10]] : complex<f32>
+// CHECK: } -> tensor<16x5xcomplex<f32>>
+// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
+
+func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+  %int-1 = torch.constant.int -1
+  %none = torch.constant.none
+  %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex<f32>>
+  return %out : !torch.vtensor<[16,5],complex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
+// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex<f32>>
+// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
+// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32>
+// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0]
+// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
+// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>>
+// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex<f32>>) outs(%[[VAR3]] : tensor<23x19xcomplex<f32>>) {
+// CHECK: ^bb0(%in: f32, %in_2: complex<f32>, %out: complex<f32>):
+// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex<f32>
+// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex<f32>
+// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32
+// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32
+// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex<f32>
+// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex<f32>
+// CHECK: linalg.yield %[[VAR12]] : complex<f32>
+// CHECK: } -> tensor<23x19xcomplex<f32>>
+// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
+// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex<f32>>) outs(%[[VAR5]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
+// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
+func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex<f32>>
+  return %out : !torch.vtensor<[19,23],complex<f32>>
+}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index f938a2637835..bf37a484720f 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -171,3 +171,47 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v
   %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
   return %0 : !torch.vtensor<[?],f16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
+// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
+// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
+// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32>
+// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32>
+// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list<int> -> !torch.vtensor<[16,5,2],f32>
+// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
+func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+  %int-1 = torch.constant.int -1
+  %none = torch.constant.none
+  %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex<f32>>
+  return %out : !torch.vtensor<[16,5],complex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
+// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19
+// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23
+// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32>
+// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32>
+// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list<int> -> !torch.vtensor<[23,19,2],f32>
+// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
+// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
+func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex<f32>>
+  return %out : !torch.vtensor<[19,23],complex<f32>>
+}