diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cb49fa97b86a..7dc0d6ebb237 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -12427,6 +12427,198 @@ class DecomposeAtenRoundDecimalsOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAsStridedOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // The `aten.as_strided` operation is decomposed into a series of
+    // operations that compute the indices based on the provided sizes and
+    // strides, and then index into the flattened input tensor, as follows:
+    //
+    // input_flat = input.view(-1)
+    //
+    // idx = None
+    // for dim, s in enumerate(self.size):
+    //   arange = torch.arange(s)
+    //   view_shape = []
+    //   for i in range(len(self.size)):
+    //     if i == dim:
+    //       view_shape.append(-1)
+    //     else:
+    //       view_shape.append(1)
+    //   arange = arange.view(view_shape) * self.stride[dim]
+    //   if dim == 0:
+    //     idx = arange
+    //   else:
+    //     idx = idx + arange
+    //
+    // # Flatten indices and add offset
+    // final_indices = idx.reshape(-1) + self.storage_offset
+    //
+    // # Index the flattened input tensor
+    // output = input_flat[final_indices]
+    //
+    // # Reshape to desired output size
+    // return output.view(self.size)
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+
+    if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "input must have known sizes");
+
+    SmallVector<int64_t> sizesInts;
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "sizes must be a list of constant ints");
+
+    SmallVector<int64_t> stridesInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "strides must be a list of constant ints");
+
+    int64_t storageOffset = 0;
+    if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
+      if (!matchPattern(op.getStorageOffset(),
+                        m_TorchConstantInt(&storageOffset)))
+        return rewriter.notifyMatchFailure(
+            op, "storage_offset must be a constant integer");
+    }
+
+    ArrayRef<int64_t> inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+    int64_t resultRank = sizesInts.size();
+
+    Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    if (inputRank > 1) {
+      // If the input is not a 1-D tensor, flatten it to a 1-D tensor before
+      // applying the strided indexing.
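+      // For example, a [4, 5, 6] input is viewed as a [120] tensor, so each
+      // gathered element is addressed by a single linear index into the
+      // contiguous storage order.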
+      int64_t flattenedInputSize = 1;
+      for (int64_t size : inputSizes)
+        flattenedInputSize *= size;
+
+      auto flattenedInputTy =
+          cast<ValueTensorType>(inputType.getWithSizesAndDtype(
+              {flattenedInputSize}, inputType.getOptionalDtype()));
+
+      Value end = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
+                                                      input, cstZero, end);
+    }
+
+    Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value cstMinusOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> viewShapeInts(resultRank, 1);
+    SmallVector<Value> viewShapeListElems(resultRank, cstOne);
+
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    Value finalIndices;
+    for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
+      int64_t size = sizesInts[dim];
+      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+      Value end = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size));
+
+      auto arangeType =
+          ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
+      Value index = rewriter.create<Torch::AtenArangeOp>(
+          loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
+
+      // Set the current dimension to -1 for broadcasting.
+      viewShapeInts[dim] = -1;
+      viewShapeListElems[dim] = cstMinusOne;
+
+      Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          viewShapeListElems);
+
+      auto viewType = ValueTensorType::get(
+          context, llvm::ArrayRef(viewShapeInts), si64Type);
+      index = rewriter.create<Torch::AtenViewOp>(loc, viewType, index,
+                                                 viewShapeList);
+
+      // Multiply the index with the stride for the current dimension.
+      Value cstStride = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
+      index =
+          rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
+
+      // Reset the current dimension to 1 for the next iteration.
+      viewShapeInts[dim] = 1;
+      viewShapeListElems[dim] = cstOne;
+
+      if (dim == 0) {
+        finalIndices = index;
+        continue;
+      }
+
+      // Calculate the common shape for the broadcasted addition.
+      SmallVector<int64_t> broadcastShape;
+      SmallVector<Value> broadcastShapeValue;
+      computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
+                            broadcastShapeValue);
+      Type broadcastType = ValueTensorType::get(
+          context, llvm::ArrayRef(broadcastShape), si64Type);
+
+      finalIndices = rewriter.create<AtenAddTensorOp>(
+          loc, broadcastType, finalIndices, index, cstOne);
+    }
+
+    int64_t flattenedResultSize = 1;
+    for (int64_t size : sizesInts)
+      flattenedResultSize *= size;
+
+    // Flatten the indices and add the storage offset.
+    finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc,
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             si64Type),
+        finalIndices, cstZero, cstMinusOne); // -1 means flatten all dims.
+
+    if (storageOffset != 0) {
+      Value cstStorageOffset = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(storageOffset));
+      finalIndices = rewriter.create<AtenAddScalarOp>(
+          loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
+    }
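+
+    // For example, with size=(2, 2), stride=(3, 3) and storage_offset=1
+    // (the parameters of the e2e test added in this patch), finalIndices
+    // is [1, 4, 4, 7], i.e. index(i, j) = 1 + 3 * i + 3 * j.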
+
+    // Index the flattened input tensor.
+    Type listElemType =
+        inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                       /*optionalDtype=*/nullptr);
+    Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(listElemType),
+        SmallVector<Value>{finalIndices});
+
+    auto flattenedResultTy =
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             inputType.getOptionalDtype());
+    Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
+                                                      input, indicesList);
+
+    // Reshape the result to the desired output size.
+    SmallVector<Value> sizesIntsValues;
+    for (int64_t size : sizesInts) {
+      sizesIntsValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size)));
+    }
+    Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        sizesIntsValues);
+    result =
+        rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12750,6 +12942,7 @@ class DecomposeComplexOpsPass
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6d6ed9cad50d..e5e5ffc6d39f 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -589,6 +589,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenConstrainRangeForSizeOp>();
   target.addIllegalOp<Aten_AssertScalarOp>();
   target.addIllegalOp<AtenRoundDecimalsOp>();
+  target.addIllegalOp<AtenAsStridedOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 25d2a01c980b..0e9f0332140f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -486,17 +486,6 @@
     "ViewSizeFromOtherTensor_basic",
     "ViewDtypeStaticModule_basic",
     "WeightNormInterfaceModule_basic",
-    # Error: `aten.as_strided` op is not supported
-    "ChunkListUnpackDynamic_Module_basic",
-    "ChunkListUnpackUnevenDynamic_Module_basic",
-    "ChunkListUnpackUneven_Module_basic",
-    "ChunkListUnpack_Module_basic",
-    "SplitTensorGetItem_Module_basic",
-    "SplitTensorLastSmallerModule_basic",
-    "SplitTensorListUnpackModule_basic",
-    "SplitTensorNegativeDimModule_basic",
-    "SplitWithSizesListUnpackModule_basic",
-    "SplitWithSizes_Module_basic",
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -526,8 +515,6 @@
     "ReflectionPad3dModuleRight_basic",
     "ReflectionPad3dModuleFront_basic",
     "ReflectionPad3dModuleBack_basic",
-    # RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
-    "NativeGroupNormModule_basic",
 }
 
 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -982,6 +969,8 @@
     "NativeGroupNormModule_basic",
     "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    "AtenAsStridedModule_basic",
+    "AtenAsStridedNoStorageOffsetModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3949,6 +3938,19 @@
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "AtenAsStridedModule_basic",
+    "AtenAsStridedNoStorageOffsetModule_basic",
+    "ChunkListUnpackDynamic_Module_basic",
+    "ChunkListUnpackUnevenDynamic_Module_basic",
+    "ChunkListUnpackUneven_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "NativeGroupNormModule_basic",
+    "SplitTensorGetItem_Module_basic",
+    "SplitTensorLastSmallerModule_basic",
+    "SplitTensorListUnpackModule_basic",
+    "SplitTensorNegativeDimModule_basic",
+    "SplitWithSizesListUnpackModule_basic",
+    "SplitWithSizes_Module_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 1ad698db9cc1..2408ca329ebe 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6730,3 +6730,48 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: Aten_AssertScalar())
 def Aten_AssertScalar_basic(module, tu: TestUtils):
     module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+
+
+class AtenAsStridedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 5, 6], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.as_strided(
+            x, size=(2, 2), stride=(3, 3), storage_offset=1
+        )
+
+
+@register_test_case(module_factory=lambda: AtenAsStridedModule())
+def AtenAsStridedModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6))
+
+
+class AtenAsStridedNoStorageOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([12, 13], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, size=(3, 4), stride=(2, 5))
+
+
+@register_test_case(module_factory=lambda: AtenAsStridedNoStorageOffsetModule())
+def AtenAsStridedNoStorageOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(12, 13))
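
A minimal standalone sketch of the gather-based rewrite that
DecomposeAtenAsStridedOp emits, useful for sanity-checking the index
arithmetic against eager PyTorch. This is not part of the patch; the helper
name `as_strided_via_gather` is hypothetical, and the sketch assumes a
contiguous input so that `reshape(-1)` matches the underlying storage order:

    import torch

    def as_strided_via_gather(x, size, stride, storage_offset=0):
        # Build one index grid per output dimension, scale it by that
        # dimension's stride, and sum the grids with broadcasting.
        idx = None
        for dim, s in enumerate(size):
            view_shape = [1] * len(size)
            view_shape[dim] = -1
            arange = torch.arange(s).view(view_shape) * stride[dim]
            idx = arange if idx is None else idx + arange
        # Gather from the flattened input, then reshape to the target size.
        flat = x.reshape(-1)
        return flat[idx.reshape(-1) + storage_offset].view(size)

    x = torch.randn(4, 5, 6)
    expected = torch.as_strided(x, (2, 2), (3, 3), 1)
    assert torch.equal(as_strided_via_gather(x, (2, 2), (3, 3), 1), expected)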