[Torch] Add decomposition of RepeatInterleaveSelfInt Op (llvm#3075)
Decompose RepeatInterleaveSelfInt using the following ops:
```python
import torch

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```
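As a quick sanity check (plain eager PyTorch, continuing from the snippet above and not part of the patch), the reference function matches `torch.Tensor.repeat_interleave` for both the dim and no-dim cases:
```python
x = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)

# No-dim case: flatten, then repeat every element `repeats` times.
assert torch.equal(my_repeat_interleave(x, 3), x.repeat_interleave(3))

# Dim case: every slice along `dim` is repeated `repeats` times in place.
assert torch.equal(my_repeat_interleave(x, 2, dim=1), x.repeat_interleave(2, dim=1))
```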

I passed the stablehlo and linalg tests. When testing onnx, however, something strange happens.
In torch-mlir's CI **torch_nightly** configuration and in my own environment (torch==2.4.0.dev20240318+cpu), the test **passes**.
In torch-mlir's CI **torch_stable** configuration, it **fails**.
The failing test case is `RepeatInterleaveSelfIntNoDimModule_basic`; the expected result shape is [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
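For context, the golden shape [120] comes from flattening before repeating (3 * 4 * 5 = 60 elements, each repeated twice), while the failing shape [6, 4, 5] is what you would get by repeating along dim 0 without flattening:
```python
import torch

x = torch.rand(3, 4, 5)
print(x.repeat_interleave(2).shape)         # torch.Size([120])   <- golden shape
print(x.repeat_interleave(2, dim=0).shape)  # torch.Size([6, 4, 5]) <- matches the failing shape
```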

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
Xinyu Yang authored Apr 17, 2024
1 parent 491f482 commit d4313ee
Showing 9 changed files with 275 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10418,6 +10418,32 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}

def Torch_AtenRepeatInterleaveSelfIntOp : Torch_Op<"aten.repeat_interleave.self_int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$repeats,
AnyTorchOptionalIntType:$dim,
AnyTorchOptionalIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRepeatInterleaveSelfIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRepeatInterleaveSelfIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenTileOp : Torch_Op<"aten.tile", [
AllowsTypeRefinement,
HasValueSemantics,
54 changes: 54 additions & 0 deletions lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -393,6 +393,59 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
PrimsCollapseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}

auto rank = selfType.getRank();
if (rank == 0)
return rewriter.notifyMatchFailure(
op, "the rank of tensor must be greater than 0");

int64_t start, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "only constant start is currently supported");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");

start = toPositiveDim(start, rank);
end = toPositiveDim(end, rank);
SmallVector<int64_t, 4> dims;
dims.reserve(rank);
for (int r = 0; r < start; ++r)
dims.push_back(r);
int64_t collapsedDimSize = 1;
for (int r = start; r <= end; ++r) {
if (selfType.getShape()[r] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
op, "the size of the dimension being collapsed is can't be unknown");
collapsedDimSize *= selfType.getShape()[r];
}
dims.push_back(collapsedDimSize);
for (int r = end + 1; r < rank; ++r)
dims.push_back(r);

auto newDimSizesInfo = hlo::getDimSizesOfTensor(
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
stablehloShape);
return success();
}

void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
@@ -405,6 +458,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_VIEW_OP_PATTERN(AtenOp) \
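In static-shape terms, the `prims.collapse` lowering above merges dims `[start, end]` of the input into a single dimension, roughly like this sketch (plain Python for illustration, not part of the patch; the actual pattern builds the shape with `getDimSizesOfTensor` and `stablehlo.dynamic_reshape`):
```python
def collapse_shape(shape, start, end):
    # Intended output shape of prims.collapse: dims [start, end] (inclusive)
    # merged into one; all of them must be statically known here.
    rank = len(shape)
    start, end = start % rank, end % rank  # toPositiveDim
    collapsed = 1
    for size in shape[start:end + 1]:
        assert size > 0, "collapsed dims must have known sizes"
        collapsed *= size
    return shape[:start] + [collapsed] + shape[end + 1:]

# Collapsing dims 1..2 of [3, 4, 2, 5] yields [3, 8, 5].
assert collapse_shape([3, 4, 2, 5], 1, 2) == [3, 8, 5]
```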
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7331,6 +7331,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" %2 = func.call @__torch__.torch.jit._shape_functions.flatten(%arg0, %int0, %int-1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.__getitem__.t %2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.mul.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %none, %2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %6 = torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>\n"
" %7 = torch.aten.add.t %3, %6 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %8 = torch.aten.add.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.aten.slice.t %arg0, %8, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %10 = torch.aten.add.t %7, %9 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" torch.prim.If.yield %10 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
@@ -10429,6 +10455,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
96 changes: 96 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2800,6 +2800,100 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
};
} // namespace

// Decompose aten.repeat_interleave.self_int into the following ops:
// aten.unsqueeze, aten.expand, and prims.collapse (or aten.flatten.using_ints
// when dim is None).
namespace {

class DecomposeAtenRepeatInterleaveSelfIntOp
: public OpRewritePattern<AtenRepeatInterleaveSelfIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRepeatInterleaveSelfIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
Value self = op.getSelf();
auto selfTy = cast<BaseTensorType>(self.getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "Unimplemented: no implementation for rankless tensor");
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasSizes())
return rewriter.notifyMatchFailure(
op, "Unimplemented: no implementation for rankless tensor");

int64_t inputRank = selfTy.getSizes().size();
int64_t repeats;
if (!matchPattern(op.getRepeats(), m_TorchConstantInt(&repeats)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: repeats not constant int");

bool dimIsNone = false;
int64_t dim;
Value dimValue = op.getDim();
if (dimValue.getType().isa<Torch::NoneType>()) {
dimIsNone = true;
dim = inputRank - 1;
} else {
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: dim not constant int");
dim = toPositiveDim(dim, inputRank);
}

dimValue =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
Value dimValuePlusOne = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim + 1));

auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
if (failed(unsqueezedInfo))
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
self = *unsqueezedInfo;

Value constMinusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
expandShapeValueList[dim + 1] = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(repeats));
Value expandShapeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), expandShapeValueList);
Value constFalse =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));

SmallVector<int64_t> expandShape(inputRank + 1);
for (int64_t i = 0; i <= dim; i++) {
expandShape[i] = selfTy.getSizes()[i];
}
expandShape[dim + 1] = repeats;
for (int64_t i = dim + 1; i < inputRank; i++) {
expandShape[i + 1] = selfTy.getSizes()[i];
}

BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
expandShape, selfTy.getOptionalDtype());

Value expandSelf = rewriter.create<AtenExpandOp>(
loc, expandTy, self, expandShapeList, constFalse);

Value result;
if (dimIsNone) {
Value constZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
result = rewriter.create<AtenFlattenUsingIntsOp>(
loc, resType, expandSelf, constZero, constMinusOne);
} else {
result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
dimValue, dimValuePlusOne);
}

rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

// Decompose aten.flatten.using_ints into aten.view op.
namespace {
class DecomposeAtenFlattenUsingIntsOp
@@ -7465,6 +7559,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);
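In eager-mode terms, the decomposition emitted by `DecomposeAtenRepeatInterleaveSelfIntOp` above corresponds roughly to the following sketch (illustrative only; the pattern itself works on value tensors and does the shape bookkeeping with torch ops):
```python
import torch

def decomposed_repeat_interleave(self, repeats, dim=None):
    # Rough eager-mode analogue of the unsqueeze -> expand -> collapse/flatten
    # sequence built by the decomposition pattern.
    rank = self.dim()
    dim_is_none = dim is None
    d = rank - 1 if dim_is_none else dim % rank  # toPositiveDim

    x = self.unsqueeze(d + 1)                    # aten.unsqueeze
    expand_shape = list(x.shape)
    expand_shape[d + 1] = repeats
    x = x.expand(*expand_shape)                  # aten.expand

    if dim_is_none:
        return x.flatten()                       # aten.flatten.using_ints
    # prims.collapse of dims [d, d + 1]
    new_shape = list(self.shape)
    new_shape[d] *= repeats
    return x.reshape(new_shape)

x = torch.rand(3, 4, 5)
assert torch.equal(decomposed_repeat_interleave(x, 2, 1), x.repeat_interleave(2, dim=1))
assert torch.equal(decomposed_repeat_interleave(x, 2), x.repeat_interleave(2))
```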
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -377,6 +377,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
target.addIllegalOp<AtenExpandOp>();
target.addIllegalOp<AtenFlattenUsingIntsOp>();
target.addIllegalOp<AtenWhereScalarOp>();
12 changes: 12 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -588,6 +588,8 @@
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"CloneModule_basic",
"CollapseAllDimensionsModule_basic",
"CollapseStaticModule_basic",
"ConstantBoolParameterModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
@@ -853,6 +855,8 @@
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"RollModule_basic",
@@ -1390,6 +1394,7 @@
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
@@ -1512,6 +1517,7 @@
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
"ViewSizeDimFollowedByExpandedOnesModule_basic",
@@ -2352,6 +2358,12 @@
"ReduceL1NormWithDTypeModule_basic",
}

if torch_version_for_comparison() < version.parse('2.3.0.dev'):
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
"RepeatInterleaveSelfIntNoDimModule_basic",
}


ONNX_CRASHING_SET = {
"FakeQuantizePerTensorAffineModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -726,6 +726,15 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
out.append(self[i] * repeats[i + leading_rank])
return out

def aten〇repeat_interleave〇self_int〡shape(self: List[int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> List[int]:
if dim is None:
flatten_size = upstream_shape_functions.flatten(self, 0, -1)[0]
return [flatten_size * repeats]
else:
out = self[:dim] + [self[dim] * repeats] + self[dim + 1:]
return out


@check_shape_function([
Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length
Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length
@@ -2625,6 +2634,11 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=1))
def aten〇repeat_interleave〇self_int〡dtype(self_rank_dtype: Tuple[int, int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1]))
def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
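A quick check of the shape function above with plain Python lists (same logic, outside the build_tools harness; like the original, it assumes a non-negative `dim`):
```python
def repeat_interleave_self_int_shape(self, repeats, dim=None):
    # Mirrors aten〇repeat_interleave〇self_int〡shape.
    if dim is None:
        flat = 1
        for s in self:
            flat *= s
        return [flat * repeats]
    return self[:dim] + [self[dim] * repeats] + self[dim + 1:]

assert repeat_interleave_self_int_shape([3, 4, 5], 2) == [120]
assert repeat_interleave_self_int_shape([3, 4, 5], 2, dim=1) == [3, 8, 5]
```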
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
Expand Up @@ -648,6 +648,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
41 changes: 41 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1842,6 +1842,47 @@ def RepeatModule_basic(module, tu: TestUtils):
# ==============================================================================


class RepeatInterleaveSelfIntModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2, 1)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntModule())
def RepeatInterleaveSelfIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================


class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class TileSmallDimsSizeModule(torch.nn.Module):

def __init__(self):
