Commit 9938abf

AtenCumprodOp (#3737)
1 parent 335cf5f commit 9938abf

File tree

7 files changed: +229 -0 lines changed

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
 
 Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                            Type elemTy);
+Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                          Type elemTy);
 
 Value castIntToIndex(OpBuilder &b, Location loc, Value v);

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 75 additions & 0 deletions
@@ -1497,6 +1497,79 @@ class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    Type elementType = resultType.getElementType();
+    Type inputElementType =
+        cast<RankedTensorType>(input.getType()).getElementType();
+
+    // Convert the input element type to the result's element type. The only
+    // possible mismatch is an integer input element type that is not `si64`,
+    // so convert the input directly to `si64`. All other cases are handled
+    // by the dtype definition for this op.
+    if (elementType != inputElementType) {
+      Value torchInput = convertTensorToDtype(
+          rewriter, loc, op.getSelf(),
+          rewriter.getIntegerType(64, IntegerType::Signed));
+      input = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(torchInput.getType()),
+          torchInput);
+    }
+
+    int64_t inputRank = resultType.getRank();
+    Value dtype = op.getDtype();
+    if (!isa<Torch::NoneType>(dtype.getType()))
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: dtype argument not supported");
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant dim value is supported");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim");
+
+    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
+    Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
+    output = rewriter.create<tensor::CastOp>(loc, resultType, output);
+
+    SmallVector<Value> accSizes(sizes);
+    accSizes.erase(accSizes.begin() + dim);
+    SmallVector<int64_t> accStatic(
+        makeShapeTorchCompatible(resultType.getShape()));
+    accStatic.erase(accStatic.begin() + dim);
+    Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
+    Type accType =
+        RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
+    acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
+
+    Value result = createTMTensorScanOp(
+        rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
+        [](OpBuilder &b, Location loc, Value input, Value acc) {
+          Value prod =
+              (isa<mlir::FloatType>(input.getType())
+                   ? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
+                   : b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
+          b.create<TMTensor::YieldOp>(loc, prod);
+        });
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
 public:

@@ -2240,6 +2313,8 @@ class ConvertTorchToTMTensor
     patterns.add<ConvertAtenSortOp>(typeConverter, context);
     target.addIllegalOp<AtenCumsumOp>();
     patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
+    target.addIllegalOp<AtenCumprodOp>();
+    patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
     target.addIllegalOp<AtenScaledDotProductAttentionOp>();
     patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
                                                          context);

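Aside (not part of the diff): the pattern above lowers aten.cumprod to an inclusive TMTensor scan whose output and accumulator are seeded with ones, the multiplicative identity. A rough Python sketch of the computation it emits for a constant dim, checked against eager torch.cumprod; cumprod_ref is a hypothetical helper written only for illustration:

import torch

def cumprod_ref(x, dim):
    # Inclusive scan: the accumulator starts at the multiplicative identity
    # (cf. createOneInitTensor) and is multiplied by each slice along `dim`.
    out = torch.ones_like(x)
    acc = torch.ones(x.shape[:dim] + x.shape[dim + 1:], dtype=x.dtype)
    for i in range(x.shape[dim]):
        acc = acc * x.select(dim, i)
        out.select(dim, i).copy_(acc)
    return out

x = torch.rand(2, 7, 4)
assert torch.allclose(cumprod_ref(x, 1), torch.cumprod(x, dim=1))
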
lib/Conversion/Utils/Utils.cpp

Lines changed: 10 additions & 0 deletions
@@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
 }
 
+Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                          Type elemTy) {
+  Value initTensor =
+      b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
+  RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
+  Value c1 =
+      b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
+  return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
+}
+
 Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
   assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
   return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 22 additions & 0 deletions
@@ -9134,6 +9134,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"

@@ -11844,6 +11847,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %2 : !torch.int\n"
+"    } else {\n"
+"      %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"      %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int4 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %2#1 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 21 additions & 0 deletions
@@ -79,6 +79,7 @@
     #### General TorchDynamo/PyTorch errors
     # torch._dynamo.exc.Unsupported: Tensor.item
     "CumsumModule_basic",
+    "CumprodModule_basic",
     # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
     # RuntimeError: Failed running call_function aten.convolution_backward(...
     # https://github.com/pytorch/pytorch/issues/89629

@@ -432,6 +433,7 @@
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "CumsumModule_basic",
+    "CumprodModule_basic",
     "DeformConv2D_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",

@@ -667,6 +669,10 @@
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "CumsumModule_basic",
+    "CumprodModule_basic",
+    "CumprodInputDtypeInt32Module_basic",
+    "CumprodStaticModule_basic",
+    "CumprodStaticNegativeDimModule_basic",
     "DeformConv2D_basic",
     "DeterminantBatchedModule_F32",
     "DeterminantDynamicModule_F32",

@@ -1077,6 +1083,9 @@
     "CumsumInputDtypeInt32Module_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "CumprodInputDtypeInt32Module_basic",
+    "CumprodStaticModule_basic",
+    "CumprodStaticNegativeDimModule_basic",
     "DetachModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",

@@ -3105,6 +3114,10 @@
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
     "CumsumInputDtypeInt32Module_basic",
+    "CumprodModule_basic",
+    "CumprodInputDtypeInt32Module_basic",
+    "CumprodStaticModule_basic",
+    "CumprodStaticNegativeDimModule_basic",
     "ElementwiseAcosIntModule_basic",
     "ElementwiseAsinIntModule_basic",
     "ElementwiseAtanTensorIntModule_basic",

@@ -3378,6 +3391,10 @@
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "CumprodModule_basic",
+    "CumprodInputDtypeInt32Module_basic",
+    "CumprodStaticModule_basic",
+    "CumprodStaticNegativeDimModule_basic",
     "DeformConv2D_basic",
     "DeterminantBatchedModule_F32",
     "DeterminantDynamicModule_F32",

@@ -4110,6 +4127,10 @@
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "CumprodModule_basic",
+    "CumprodInputDtypeInt32Module_basic",
+    "CumprodStaticModule_basic",
+    "CumprodStaticNegativeDimModule_basic",
     "DeformConv2D_basic",
     "DeterminantModule_F32",
     "DeterminantBatchedModule_F32",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 15 additions & 0 deletions
@@ -1434,6 +1434,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b
 def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -2926,6 +2929,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
         return torch.int64
     return self_dtype
 
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
+def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
+    if dtype is not None:
+        return dtype
+    self_rank, self_dtype = self_rank_dtype
+    if is_integer_dtype(self_dtype):
+        return torch.int64
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype

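Aside (not part of the diff): the dtype rule above mirrors eager PyTorch behaviour, which is what @check_dtype_function validates against. A quick illustrative check in plain PyTorch, using only standard torch calls:

import torch

x_i32 = torch.arange(1, 5, dtype=torch.int32)
x_f64 = torch.rand(4, dtype=torch.float64)

assert torch.cumprod(x_i32, dim=0).dtype == torch.int64    # integer inputs promote to int64
assert torch.cumprod(x_f64, dim=0).dtype == torch.float64  # floating dtypes are preserved
assert torch.cumprod(x_i32, dim=0, dtype=torch.float32).dtype == torch.float32  # explicit dtype wins
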
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 84 additions & 0 deletions
@@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class CumprodModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        ones = torch.ones([1], dtype=torch.int32)
+        return torch.ops.aten.cumprod(val, ones.item())
+
+
+@register_test_case(module_factory=lambda: CumprodModule())
+def CumprodModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, 1)
+
+
+@register_test_case(module_factory=lambda: CumprodStaticModule())
+def CumprodStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, dim=-1)
+
+
+@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
+def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodInputDtypeInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.int32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, 1)
+
+
+@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
+def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 7, 4).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class AtenToDeviceModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

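Aside (not part of the diff): the e2e harness compares each compiled module against eager PyTorch. A standalone illustration of the two properties the static and negative-dim variants exercise, written outside the test framework:

import torch

val = torch.rand(2, 7, 4)

# cumprod is shape-preserving, matching the identity shape function added above.
assert torch.ops.aten.cumprod(val, 1).shape == val.shape

# A negative dim is normalized (dim=-1 equals dim=2 for a rank-3 tensor), which
# is what toPositiveDim handles in the conversion pattern.
assert torch.equal(torch.ops.aten.cumprod(val, -1), torch.ops.aten.cumprod(val, 2))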
0 commit comments

Comments
 (0)