
[MLIR][TORCH] Add E2E support for aten.as_strided op #4269


Open: wants to merge 3 commits into main
193 changes: 193 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -12427,6 +12427,198 @@ class DecomposeAtenRoundDecimalsOp
};
} // namespace

namespace {
class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
public:
using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAsStridedOp op,
PatternRewriter &rewriter) const override {

// The `aten.as_strided` operation is decomposed into a series of
// operations that compute the indices based on the provided sizes and
// strides, and then index into the flattened input tensor as follows:

// input_flat = input.view(-1)
//
// for dim, s in enumerate(self.size):
//     arange = torch.arange(s)
//     view_shape = []
//     for i in range(len(self.size)):
//         if i == dim:
//             view_shape.append(-1)
//         else:
//             view_shape.append(1)
//     arange = arange.view(view_shape) * self.stride[dim]
//     if dim == 0:
//         idx = arange
//     else:
//         idx = idx + arange
//
// # Flatten indices and add offset
// final_indices = idx.reshape(-1) + self.storage_offset
//
// # Index the flattened input tensor
// output = input_flat[final_indices]
//
// # Reshape to desired output size
// return output.view(self.size)

Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value input = op.getSelf();
auto inputType = dyn_cast<BaseTensorType>(input.getType());

if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "input must have known sizes");

SmallVector<int64_t> sizesInts;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
return rewriter.notifyMatchFailure(
op, "sizes must be a list of constant ints");

SmallVector<int64_t> stridesInts;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
return rewriter.notifyMatchFailure(
op, "strides must be a list of constant ints");

int64_t storageOffset = 0;
if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
if (!matchPattern(op.getStorageOffset(),
m_TorchConstantInt(&storageOffset)))
return rewriter.notifyMatchFailure(
op, "storage_offset must be a constant integer");
}

ArrayRef<int64_t> inputSizes = inputType.getSizes();
int64_t inputRank = inputSizes.size();
int64_t resultRank = sizesInts.size();

Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
if (inputRank > 1) {
// If the input is not already 1-D, flatten it to a 1-D tensor
// before applying the strided indexing.
int64_t flattenedInputSize = 1;
for (int64_t size : inputSizes)
flattenedInputSize *= size;

auto flattenedInputTy =
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
{flattenedInputSize}, inputType.getOptionalDtype()));

Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
input, cstZero, end);
}

Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value cstMinusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));

SmallVector<int64_t> viewShapeInts(resultRank, 1);
SmallVector<Value> viewShapeListElems(resultRank, cstOne);

auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
Value finalIndices;
for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
int64_t size = sizesInts[dim];
Value cstNone = rewriter.create<ConstantNoneOp>(loc);
Value end =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(size));

auto arangeType =
ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
Value index = rewriter.create<Torch::AtenArangeOp>(
loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);

// Set the current dimension to -1 for broadcasting
viewShapeInts[dim] = -1;
viewShapeListElems[dim] = cstMinusOne;

Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
viewShapeListElems);

auto viewType = ValueTensorType::get(
context, llvm::ArrayRef(viewShapeInts), si64Type);
index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);

// Multiply the index with the stride for the current dimension
Value cstStride = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);

// Reset the current dimension to 1 for the next iteration
viewShapeInts[dim] = 1;
viewShapeListElems[dim] = cstOne;

if (dim == 0) {
finalIndices = index;
continue;
}

// Calculate the common shape for broadcasting the two index tensors
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
broadcastShapeValue);
Type broadcastType = ValueTensorType::get(
context, llvm::ArrayRef(broadcastShape), si64Type);

finalIndices = rewriter.create<AtenAddTensorOp>(
loc, broadcastType, finalIndices, index, cstOne);
}
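// Example of the accumulation above (hypothetical values): for
// size = [2, 3] and stride = [3, 1], dim 0 contributes [[0], [3]],
// dim 1 contributes [[0, 1, 2]], and the broadcasted sum yields the
// index matrix [[0, 1, 2], [3, 4, 5]].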

int64_t flattenedResultSize = 1;
for (int64_t size : sizesInts)
flattenedResultSize *= size;

// Flatten the indices and add the storage offset
finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
loc,
ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
si64Type),
finalIndices, cstZero, cstMinusOne); // -1 means flatten all

if (storageOffset != 0) {
Value cstStorageOffset = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(storageOffset));
finalIndices = rewriter.create<AtenAddScalarOp>(
loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
}

// Index the flattened input tensor
Type listElemType =
inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(listElemType),
SmallVector<Value>{finalIndices});

auto flattenedResultTy =
ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
inputType.getOptionalDtype());
Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
input, indicesList);

// Reshape the result to the desired output size
SmallVector<Value> sizesIntsValues;
for (int64_t size : sizesInts) {
sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));
}
Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
sizesIntsValues);
result =
rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);

rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12750,6 +12942,7 @@ class DecomposeComplexOpsPass
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);

GreedyRewriteConfig config;
config.setUseTopDownTraversal(true);
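As a sanity check on the decomposition above, the same index arithmetic can be written in eager PyTorch and compared against torch.as_strided directly. This is a minimal sketch mirroring the pattern's pseudocode; the helper name and test values are illustrative, not part of the patch:

import torch

def as_strided_via_gather(x, size, stride, storage_offset=0):
    # Flatten the input, build the gather indices by broadcasting one
    # strided arange per output dimension, then gather and reshape.
    input_flat = x.reshape(-1)
    idx = None
    for dim, s in enumerate(size):
        # view_shape is [1, ..., -1, ..., 1] with -1 at position `dim`.
        view_shape = [-1 if i == dim else 1 for i in range(len(size))]
        arange = torch.arange(s).view(view_shape) * stride[dim]
        idx = arange if idx is None else idx + arange
    final_indices = idx.reshape(-1) + storage_offset
    return input_flat[final_indices].view(size)

x = torch.randn(4, 5, 6)
expected = torch.as_strided(x, size=(2, 2), stride=(3, 3), storage_offset=1)
assert torch.equal(as_strided_via_gather(x, (2, 2), (3, 3), 1), expected)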
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -589,6 +589,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLogaddexpOp>();
target.addIllegalOp<AtenLogaddexp2Op>();
target.addIllegalOp<AtenKlDivOp>();
target.addIllegalOp<AtenAsStridedOp>();

for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
28 changes: 15 additions & 13 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -486,17 +486,6 @@
"ViewSizeFromOtherTensor_basic",
"ViewDtypeStaticModule_basic",
"WeightNormInterfaceModule_basic",
# Error: `aten.as_strided` op is not supported
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -526,8 +515,6 @@
"ReflectionPad3dModuleRight_basic",
"ReflectionPad3dModuleFront_basic",
"ReflectionPad3dModuleBack_basic",
# RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
"NativeGroupNormModule_basic",
}

FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -982,6 +969,8 @@
"NativeGroupNormModule_basic",
"AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
"MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
"AtenAsStridedModule_basic",
"AtenAsStridedNoStorageOffsetModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3949,6 +3938,19 @@
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
"AtenAsStridedModule_basic",
"AtenAsStridedNoStorageOffsetModule_basic",
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"NativeGroupNormModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
}

ONNX_TOSA_CRASHING_SET = {
45 changes: 45 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6730,3 +6730,48 @@ def forward(self, x):
@register_test_case(module_factory=lambda: Aten_AssertScalar())
def Aten_AssertScalar_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================


class AtenAsStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4, 5, 6], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.as_strided(
x, size=(2, 2), stride=(3, 3), storage_offset=1
)


@register_test_case(module_factory=lambda: AtenAsStridedModule())
def AtenAsStridedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6))
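# Expected behavior of the test above (illustrative): with size = (2, 2),
# stride = (3, 3), and storage_offset = 1, the gathered flat indices are
# 1 + i * 3 + j * 3 = [[1, 4], [4, 7]], so the flat element at index 4 is
# read twice (the strided views overlap).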


[Review comment, Member]: LGTM. A clarification: are there general guidelines for adding tests? Are e2e tests sufficient, or are LIT tests required as well? Thanks!

class AtenAsStridedNoStorageOffsetModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([12, 13], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.as_strided(x, size=(3, 4), stride=(2, 5))


@register_test_case(module_factory=lambda: AtenAsStridedNoStorageOffsetModule())
def AtenAsStridedNoStorageOffsetModule_basic(module, tu: TestUtils):
module.forward(torch.randn(12, 13))
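# For this variant (illustrative): with size = (3, 4), stride = (2, 5), and no
# storage offset, the gathered flat indices are i * 2 + j * 5, i.e.
# [[0, 5, 10, 15], [2, 7, 12, 17], [4, 9, 14, 19]] into x.reshape(-1).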