@@ -5523,6 +5523,192 @@ class DecomposeAtenConvolutionBackwardOp
5523
5523
};
5524
5524
} // namespace
5525
5525
5526
+ class DecomposeAtenNonzeroOp : public OpRewritePattern <AtenNonzeroOp> {
5527
+ using OpRewritePattern::OpRewritePattern;
5528
+ LogicalResult matchAndRewrite (AtenNonzeroOp op,
5529
+ PatternRewriter &rewriter) const override {
5530
+ Location loc = op.getLoc ();
5531
+ auto si64Type = rewriter.getIntegerType (64 , true );
5532
+ Value si64Dtype = getDtypeIntValueForType (rewriter, loc, si64Type);
5533
+ // helper for making int constants
5534
+ std::function<Value (int64_t )> c = [&](int64_t val) {
5535
+ Value newIntConstant =
5536
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (val));
5537
+ return newIntConstant;
5538
+ };
5539
+ std::function<Value (Value)> makeOneElementList = [&](Value element) {
5540
+ auto listType = Torch::ListType::get (element.getType ());
5541
+ return rewriter.create <PrimListConstructOp>(loc, listType,
5542
+ ArrayRef<Value>{element});
5543
+ };
5544
+
5545
+ Value input = op.getSelf ();
5546
+ auto inputType = dyn_cast<BaseTensorType>(input.getType ());
5547
+ int64_t inputRank = inputType.getSizes ().size ();
5548
+
5549
+ // original_shape = t.shape
5550
+ auto shapeType = Torch::ValueTensorType::get (
5551
+ rewriter.getContext (), SmallVector<int64_t >{inputRank}, si64Type);
5552
+ Value inputShapeTensor =
5553
+ rewriter.create <Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
5554
+
5555
+ // t = flatten(t)
5556
+ int64_t flattenedSize = 1 ;
5557
+ if (inputType.hasSizes ()) {
5558
+ for (auto size : inputType.getSizes ()) {
5559
+ flattenedSize *= size;
5560
+ }
5561
+ } else {
5562
+ flattenedSize = kUnknownSize ;
5563
+ }
5564
+
5565
+ auto flattendInputShape = SmallVector<int64_t >{flattenedSize};
5566
+ auto flattenedInputType = rewriter.getType <Torch::ValueTensorType>(
5567
+ flattendInputShape, inputType.getOptionalDtype ());
5568
+
5569
+ Value inputDimsStart =
5570
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (0 ));
5571
+ Value inputDimsEnd = rewriter.create <ConstantIntOp>(
5572
+ loc, rewriter.getI64IntegerAttr (inputRank - 1 ));
5573
+
5574
+ Value flattenedInput = rewriter.create <AtenFlattenUsingIntsOp>(
5575
+ loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
5576
+
5577
+ // nonzero_mask = (t != 0)
5578
+ auto boolMaskType = inputType.getWithSizesAndDtype (
5579
+ flattenedInputType.getOptionalSizes (), rewriter.getI1Type ());
5580
+ Value boolMask = rewriter.create <AtenNeScalarOp>(loc, boolMaskType,
5581
+ flattenedInput, c (0 ));
5582
+
5583
+ // nonzero_mask = nonzero_mask.int()
5584
+ Value falseCst = rewriter.create <ConstantBoolOp>(loc, false );
5585
+ Value noneCst = rewriter.create <ConstantNoneOp>(loc);
5586
+ auto intMaskType = flattenedInputType.getWithSizesAndDtype (
5587
+ flattenedInputType.getOptionalSizes (), si64Type); // ####
5588
+ Value intMask = rewriter.create <AtenToDtypeOp>(
5589
+ loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
5590
+
5591
+ // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
5592
+ auto cumulativeSumType =
5593
+ dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype (
5594
+ flattenedInputType.getOptionalSizes (), si64Type));
5595
+ Value cumulativeSum = rewriter.create <AtenCumsumOp>(loc, cumulativeSumType,
5596
+ intMask, c (0 ), noneCst);
5597
+ Value one =
5598
+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (1 ));
5599
+ Value subtracted = rewriter.create <AtenSubScalarOp>(
5600
+ loc, cumulativeSumType, cumulativeSum, one, /* alpha=*/ one);
5601
+
5602
+ // destination_indices = torch.clamp(destination_indices, min=0)
5603
+ Value indices = rewriter.create <AtenClampMinOp>(loc, cumulativeSumType,
5604
+ subtracted, c (0 ));
5605
+
5606
+ // iota = torch.tensor(range(len(t))) * nonzero_mask.int()
5607
+ Value rangeTensor = rewriter.create <AtenArangeStartStepOp>(
5608
+ loc, cumulativeSumType, c (0 ),
5609
+ rewriter.create <ConstantIntOp>(
5610
+ loc, rewriter.getI64IntegerAttr (flattenedInputType.getSizes ()[0 ])),
5611
+ one, noneCst, noneCst, noneCst, noneCst);
5612
+ Value multiplied = rewriter.create <AtenMulTensorOp>(loc, cumulativeSumType,
5613
+ rangeTensor, intMask);
5614
+
5615
+ // scatter_self = torch.zeros_like(t, dtype=torch.int64)
5616
+ // AtenFullLike doesn't support index type so we have to use si64
5617
+ auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype (
5618
+ cumulativeSumType.getOptionalSizes (), si64Type);
5619
+ Value zerosTensor = rewriter.create <AtenZerosLikeOp>(
5620
+ loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
5621
+ noneCst, noneCst);
5622
+
5623
+ // compacted = scatter_self.scatter_(
5624
+ // dim=0,
5625
+ // index=destination_indices,
5626
+ // src=iota, reduce='add')
5627
+ Value reduceStr = rewriter.create <ConstantStrOp>(loc, " sum" );
5628
+ Value constAxis = rewriter.create <Torch::ConstantIntOp>(
5629
+ loc, rewriter.getType <Torch::IntType>(),
5630
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
5631
+
5632
+ Value cstFalse = rewriter.create <ConstantBoolOp>(loc, false );
5633
+ Value scatteredTensor = rewriter.create <AtenScatterReduceTwoOp>(
5634
+ loc, cumulativeSumType, zerosTensor, /* axis=*/ constAxis,
5635
+ /* dims=*/ indices, /* src=*/ multiplied, reduceStr, cstFalse);
5636
+
5637
+ // result_flat = compacted[:torch.sum(nonzero_mask)]
5638
+ auto scalarType = ValueTensorType::get (rewriter.getContext (),
5639
+ ArrayRef<int64_t >{}, si64Type);
5640
+ Value sumMask =
5641
+ rewriter.create <AtenSumOp>(loc, scalarType, intMask, noneCst);
5642
+ Value numNonzero = rewriter.create <AtenIntTensorOp>(loc, sumMask);
5643
+
5644
+ auto slicedResultType = Torch::ValueTensorType::get (
5645
+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize }, si64Type);
5646
+ Value slicedResult =
5647
+ rewriter.create <AtenSliceTensorOp>(loc, slicedResultType,
5648
+ /* self=*/ scatteredTensor,
5649
+ /* dim=*/ c (0 ),
5650
+ /* start=*/ c (0 ),
5651
+ /* end=*/ numNonzero,
5652
+ /* step=*/ one);
5653
+
5654
+ // strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
5655
+ Value flippedShape = rewriter.create <AtenFlipOp>(
5656
+ loc, shapeType, inputShapeTensor, makeOneElementList (c (0 )));
5657
+ Value cumulativeProduct = rewriter.create <AtenCumprodOp>(
5658
+ loc, shapeType, flippedShape, c (0 ), noneCst);
5659
+ Value flippedCumulativeProduct = rewriter.create <AtenFlipOp>(
5660
+ loc, shapeType, cumulativeProduct, makeOneElementList (c (0 )));
5661
+ // strides = torch.cat([strides[1:], torch.tensor([1],
5662
+ // device=t.device)])
5663
+ auto oneTensorType = ValueTensorType::get (
5664
+ rewriter.getContext (), SmallVector<int64_t >{1 }, si64Type);
5665
+ Value oneTensor = rewriter.create <AtenScalarTensorOp>(
5666
+ loc, oneTensorType, c (1 ), si64Dtype, noneCst, noneCst, noneCst);
5667
+
5668
+ auto slicedStrideType = Torch::ValueTensorType::get (
5669
+ rewriter.getContext (), SmallVector<int64_t >{inputRank - 1 }, // sizes
5670
+ si64Type);
5671
+ Value strideSliceStart = c (1 );
5672
+ Value strideSliceEnd = c (inputRank);
5673
+ Value slicedStrides = rewriter.create <AtenSliceTensorOp>(
5674
+ loc, slicedStrideType, flippedCumulativeProduct, /* dim*/ c (0 ),
5675
+ /* start=*/ strideSliceStart, /* end=*/ strideSliceEnd, /* step=*/ c (1 ));
5676
+
5677
+ auto tensorListElementType = Torch::ValueTensorType::get (
5678
+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize }, si64Type);
5679
+ Value tensorList = rewriter.create <Torch::PrimListConstructOp>(
5680
+ loc, Torch::ListType::get (tensorListElementType),
5681
+ SmallVector<Value>{slicedStrides, oneTensor});
5682
+ Value strides =
5683
+ rewriter.create <Torch::AtenCatOp>(loc, shapeType, tensorList, c (0 ));
5684
+
5685
+ // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
5686
+ // inputShapeTensor
5687
+ auto unsqueezedResultType = ValueTensorType::get (
5688
+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize , 1 }, si64Type);
5689
+ Value unsqueezedResult = rewriter.create <AtenUnsqueezeOp>(
5690
+ loc, unsqueezedResultType, slicedResult, c (1 ));
5691
+
5692
+ auto unsqueezedStridesType = ValueTensorType::get (
5693
+ rewriter.getContext (), SmallVector<int64_t >{1 , inputRank}, si64Type);
5694
+ Value unsqueezedStrides = rewriter.create <AtenUnsqueezeOp>(
5695
+ loc, unsqueezedStridesType, strides, c (0 ));
5696
+
5697
+ auto dividedBroadcastType = ValueTensorType::get (
5698
+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize , inputRank},
5699
+ si64Type);
5700
+ Value divided = rewriter.create <AtenFloorDivideOp>(
5701
+ loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
5702
+
5703
+ auto resultType = cast<BaseTensorType>(op.getType ());
5704
+ Value modded = rewriter.create <AtenRemainderTensorOp>(
5705
+ loc, resultType, divided, inputShapeTensor);
5706
+
5707
+ rewriter.replaceOp (op, modded);
5708
+ return success ();
5709
+ }
5710
+ };
5711
+
5526
5712
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
5527
5713
namespace {
5528
5714
class DecomposeAtenAddmmOp : public OpRewritePattern <AtenAddmmOp> {
@@ -10573,6 +10759,7 @@ class DecomposeComplexOpsPass
10573
10759
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
10574
10760
patterns);
10575
10761
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
10762
+ addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
10576
10763
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
10577
10764
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
10578
10765
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
0 commit comments