Skip to content

Commit

Permalink
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (llvm#3140)
Browse files Browse the repository at this point in the history
Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
    def forward(self, x):
        if len(x.shape) == 5:
            b0, t, c0, h0, w0 = x.shape
            b, c, h, w = torch.mul(b0, t), c0, h0, w0
        else:
            b1, c1, h1, w1 = x.shape
            b, c, h, w = b1, c1, h1, w1
        res = torch.reshape(x, [b, c, h, w])
        return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
  • Loading branch information
Xinyu Yang authored Apr 11, 2024
1 parent 6524838 commit 308c45e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3088,6 +3088,9 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
if (!listConstruct)
return failure();

if (op->getNumResults() != listConstruct.getElements().size())
return failure();

rewriter.replaceOp(op, listConstruct.getElements());
return success();
});
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
Expand Down Expand Up @@ -1216,6 +1217,7 @@
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
Expand Down Expand Up @@ -1391,6 +1393,7 @@
"ElementwisePreluStaticModule_basic",

# Shape Related failures
"PrimListUnpackNumMismatchModule_basic",
"ReshapeExpandModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,33 @@ def SliceCopyNonZeroDim_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(10, 2, 4))


# ==============================================================================
class PrimListUnpackNumMismatchModule(torch.nn.Module):
def __init__(self):
super().__init__()


@export
@annotate_args([
None,
([5, 4, 3, 2, 1], torch.float32, True),
])
def forward(self, x):
if len(x.shape) == 5:
b0, t, c0, h0, w0 = x.shape
b, c, h, w = torch.mul(b0, t), c0, h0, w0
else:
b1, c1, h1, w1 = x.shape
b, c, h, w = b1, c1, h1, w1
res = torch.reshape(x, [b, c, h, w])
return res


@register_test_case(module_factory=lambda: PrimListUnpackNumMismatchModule())
def PrimListUnpackNumMismatchModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2, 1))


# ==============================================================================


Expand Down

0 comments on commit 308c45e

Please sign in to comment.