@@ -1497,6 +1497,79 @@ class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
1497
1497
};
1498
1498
} // namespace
1499
1499
1500
+ namespace {
1501
+ class ConvertAtenCumprodOp : public OpConversionPattern <AtenCumprodOp> {
1502
+ public:
1503
+ using OpConversionPattern::OpConversionPattern;
1504
+ LogicalResult
1505
+ matchAndRewrite (AtenCumprodOp op, OpAdaptor adaptor,
1506
+ ConversionPatternRewriter &rewriter) const override {
1507
+
1508
+ Location loc = op.getLoc ();
1509
+ Value input = adaptor.getSelf ();
1510
+ auto resultType = cast<RankedTensorType>(
1511
+ getTypeConverter ()->convertType (op->getResult (0 ).getType ()));
1512
+ Type elementType = resultType.getElementType ();
1513
+ Type inputElementType =
1514
+ cast<RankedTensorType>(input.getType ()).getElementType ();
1515
+
1516
+ // Converting the input element type to the result's element type.
1517
+ // The only possible mismatch would be when the input element type is an
1518
+ // integer but not `si64`. Therefore, we directly convert the input to
1519
+ // `si64`. Rest all cases are handled in the dtype definition for this op.
1520
+ if (elementType != inputElementType) {
1521
+ Value torchInput = convertTensorToDtype (
1522
+ rewriter, loc, op.getSelf (),
1523
+ rewriter.getIntegerType (64 , IntegerType::Signed));
1524
+ input = typeConverter->materializeTargetConversion (
1525
+ rewriter, loc, typeConverter->convertType (torchInput.getType ()),
1526
+ torchInput);
1527
+ }
1528
+
1529
+ int64_t inputRank = resultType.getRank ();
1530
+ Value dtype = op.getDtype ();
1531
+ if (!isa<Torch::NoneType>(dtype.getType ()))
1532
+ return rewriter.notifyMatchFailure (
1533
+ op, " unsupported: dtype argument not supported" );
1534
+
1535
+ int64_t dim;
1536
+ if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim)))
1537
+ return rewriter.notifyMatchFailure (
1538
+ op, " unimplemented: only constant dim value is supported" );
1539
+ dim = toPositiveDim (dim, inputRank);
1540
+ if (!isValidDim (dim, inputRank))
1541
+ return rewriter.notifyMatchFailure (op, " invalid dim" );
1542
+
1543
+ SmallVector<Value> sizes = getTensorSizes (rewriter, loc, input);
1544
+ Value output = createOneInitTensor (rewriter, loc, sizes, elementType);
1545
+ output = rewriter.create <tensor::CastOp>(loc, resultType, output);
1546
+
1547
+ SmallVector<Value> accSizes (sizes);
1548
+ accSizes.erase (accSizes.begin () + dim);
1549
+ SmallVector<int64_t > accStatic (
1550
+ makeShapeTorchCompatible (resultType.getShape ()));
1551
+ accStatic.erase (accStatic.begin () + dim);
1552
+ Value acc = createOneInitTensor (rewriter, loc, accSizes, elementType);
1553
+ Type accType =
1554
+ RankedTensorType::get (makeShapeLLVMCompatible (accStatic), elementType);
1555
+ acc = rewriter.create <tensor::CastOp>(loc, accType, acc);
1556
+
1557
+ Value result = createTMTensorScanOp (
1558
+ rewriter, loc, input, output, acc, dim, /* inclusive=*/ true ,
1559
+ [](OpBuilder &b, Location loc, Value input, Value acc) {
1560
+ Value prod =
1561
+ (isa<mlir::FloatType>(input.getType ())
1562
+ ? b.create <arith::MulFOp>(loc, input, acc)->getResult (0 )
1563
+ : b.create <arith::MulIOp>(loc, input, acc)->getResult (0 ));
1564
+ b.create <TMTensor::YieldOp>(loc, prod);
1565
+ });
1566
+
1567
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
1568
+ return success ();
1569
+ }
1570
+ };
1571
+ } // namespace
1572
+
1500
1573
namespace {
1501
1574
class ConvertAtenCumsumOp : public OpConversionPattern <AtenCumsumOp> {
1502
1575
public:
@@ -2240,6 +2313,8 @@ class ConvertTorchToTMTensor
2240
2313
patterns.add <ConvertAtenSortOp>(typeConverter, context);
2241
2314
target.addIllegalOp <AtenCumsumOp>();
2242
2315
patterns.add <ConvertAtenCumsumOp>(typeConverter, context);
2316
+ target.addIllegalOp <AtenCumprodOp>();
2317
+ patterns.add <ConvertAtenCumprodOp>(typeConverter, context);
2243
2318
target.addIllegalOp <AtenScaledDotProductAttentionOp>();
2244
2319
patterns.add <ConvertAtenScaledDotProductAttentionOp>(typeConverter,
2245
2320
context);
0 commit comments