forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (llvm#3075)
Decomposition RepeatInterleaveSelfInt with following ops: ```python def my_repeat_interleave(input, repeats, dim=None): if dim is None: # Flatten the input and then repeat return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten() else: # Calculate the shape after repeat expanded_shape = list(input.shape) expanded_shape[dim] *= repeats # Repeat the tensor along the specified dimension repeat_shape = [1] * (input.dim() + 1) repeat_shape[dim + 1] = repeats input = input.unsqueeze(-1) # Tile and then reshape tiled = torch.tile(input, repeat_shape) # Rearrange and reshape repeated = tiled.reshape(*expanded_shape) return repeated ``` I passed the tests of stablehlo and linalg. When testing onnx, strange things happened. In torch-mlir's CI **torch_nightly** and my own environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**. In torch-mlir's CI **torch_stable**, it **failed**. The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result shape should be [120]. ```python class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([3, 4, 5], torch.float32, True), ]) def forward(self, x): return x.repeat_interleave(2) @register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule()) def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) ``` The error log is as follows: ``` Unexpected outcome summary: (onnx) ****** Failed tests - 1 tests FAIL - "RepeatInterleaveSelfIntNoDimModule_basic" @ trace item #0 - call to "forward" @ output of call to "forward" ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120])) ``` @rsuderman Would you please help me check what's wrong with my PR? Thanks a lot.
- Loading branch information
Xinyu Yang
authored
Apr 17, 2024
1 parent
491f482
commit d4313ee
Showing
9 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters