Skip to content

[MLIR][TORCH] Add E2E support for aten.as_strided op #4269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

vivekkhandelwal1
Copy link
Collaborator

@vivekkhandelwal1 vivekkhandelwal1 commented Jul 11, 2025

This commit adds the e2e support for the aten.as_strided op by decomposing it into a series of other torch operations.

Fixes #4191.

The failing tests for Tosa config are tracked by #4272.

@vivekkhandelwal1
Copy link
Collaborator Author

The CI will fail for the nightly run. That's because of a different issue related to the PyTorch integration. I will fix that in a separate patch.

This commit adds the e2e support for the aten.as_strided op by
decomposing it into a series of other torch operations.

Signed-off-by: Vivek Khandelwal <[email protected]>
@vivekkhandelwal1
Copy link
Collaborator Author

@sjarus @justin-ngo-arm Some of the tests are failing for Tosa config since I have added a decomp for the as_strided op.

@justin-ngo-arm
Copy link
Contributor

@vivekkhandelwal1 Can you let me know which tests are failing for TOSA? Do you want us to fix those for you, or will you fix them/add them to XFail?

@vivekkhandelwal1
Copy link
Collaborator Author

@vivekkhandelwal1 Can you let me know which tests are failing for TOSA? Do you want us to fix those for you, or will you fix them/add them to XFail?

@justin-ngo-arm In this patch, I have added a decomposition for the as_strided op. As a result, the tests that were passing earlier for the Tosa config are now failing because, instead of being lowered in the TorchToTosa pipeline, the op is being decomposed and cannot be converted to Tosa. Hence, the tests that passed when you added the lowering for as_strided, some of them are now failing.

In total, 11 tests are failing; they are as follows:

ChunkListUnpackDynamic_Module_basic
ChunkListUnpackUnevenDynamic_Module_basic
ChunkListUnpackUneven_Module_basic
ChunkListUnpack_Module_basic
NativeGroupNormModule_basic
SplitTensorGetItem_Module_basic
SplitTensorLastSmallerModule_basic
SplitTensorListUnpackModule_basic
SplitTensorNegativeDimModule_basic
SplitWithSizesListUnpackModule_basic
SplitWithSizes_Module_basic

I don't know how much time it would take to fix these tests, and that has to be done in a separate patch. So, if it's not blocking you, then can I just xfail them in this patch? And, you may raise a separate PR with the fix where the tests can be re-enabled.

@justin-ngo-arm
Copy link
Contributor

@vivekkhandelwal1 Marking them with XFail works for me. Can you do me a favor by creating an issue with a list of those failed tests and assign them to me please? That will serve as a better reminder for me to track this, and I will fix them when I have time to do so. Thanks a lot!

@vivekkhandelwal1
Copy link
Collaborator Author

@vivekkhandelwal1 Marking them with XFail works for me. Can you do me a favor by creating an issue with a list of those failed tests and assign them to me please? That will serve as a better reminder for me to track this, and I will fix them when I have time to do so. Thanks a lot!

Sure, I will do that.

@vivekkhandelwal1
Copy link
Collaborator Author

In this patch, I have added a decomposition for the as_strided op. As a result, the tests that were passing earlier for the Tosa config are now failing because, instead of being lowered in the TorchToTosa pipeline, the op is being decomposed and cannot be converted to Tosa. Hence, the tests that passed when you added the lowering for as_strided, some of them are now failing.

In total, 11 tests are failing; they are as follows:

ChunkListUnpackDynamic_Module_basic
ChunkListUnpackUnevenDynamic_Module_basic
ChunkListUnpackUneven_Module_basic
ChunkListUnpack_Module_basic
NativeGroupNormModule_basic
SplitTensorGetItem_Module_basic
SplitTensorLastSmallerModule_basic
SplitTensorListUnpackModule_basic
SplitTensorNegativeDimModule_basic
SplitWithSizesListUnpackModule_basic
SplitWithSizes_Module_basic

Hi @justin-ngo-arm, I have created an issue here #4272.

@justin-ngo-arm
Copy link
Contributor

Thank you @vivekkhandelwal1! I really appreciate it.

Copy link
Contributor

@praveen-g-ctt praveen-g-ctt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

module.forward(torch.randn(4, 5, 6))


class AtenAsStridedNoStorageOffsetModule(torch.nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

A clarification: Is there a general guidelines for adding tests? Are e2e tests sufficient or LIT tests are required as well?

Thanks!

@sjarus
Copy link
Collaborator

sjarus commented Jul 16, 2025

Looks like @justin-ngo-arm has this covered @vivekkhandelwal1 . I tried to add him on the issue, but for some reason his ID doesn't show up as someone who could be added as reviewer on the issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

(TorchToLinalg) Support for lowering torch.aten.as_strided
6 participants