
Commit 35e20e0

[Torch] Add decompose for 1d torch.nonzero
1 parent 06d1789 commit 35e20e0

2 files changed: +210 -0 lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 187 additions & 0 deletions
@@ -5523,6 +5523,192 @@ class DecomposeAtenConvolutionBackwardOp

```diff
 };
 } // namespace

+class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNonzeroOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto si64Type = rewriter.getIntegerType(64, true);
+    Value si64Dtype = getDtypeIntValueForType(rewriter, loc, si64Type);
+    // Helper for making int constants.
+    std::function<Value(int64_t)> c = [&](int64_t val) {
+      Value newIntConstant =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(val));
+      return newIntConstant;
+    };
+    std::function<Value(Value)> makeOneElementList = [&](Value element) {
+      auto listType = Torch::ListType::get(element.getType());
+      return rewriter.create<PrimListConstructOp>(loc, listType,
+                                                  ArrayRef<Value>{element});
+    };
+
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    int64_t inputRank = inputType.getSizes().size();
+
+    // original_shape = t.shape
+    auto shapeType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
+    Value inputShapeTensor =
+        rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
+
+    // t = flatten(t)
+    int64_t flattenedSize = 1;
+    if (inputType.hasSizes()) {
+      for (auto size : inputType.getSizes()) {
+        flattenedSize *= size;
+      }
+    } else {
+      flattenedSize = kUnknownSize;
+    }
+
+    auto flattendInputShape = SmallVector<int64_t>{flattenedSize};
+    auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
+        flattendInputShape, inputType.getOptionalDtype());
+
+    Value inputDimsStart =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value inputDimsEnd = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputRank - 1));
+
+    Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
+
+    // nonzero_mask = (t != 0)
+    auto boolMaskType = inputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
+    Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
+                                                     flattenedInput, c(0));
+
+    // nonzero_mask = nonzero_mask.int()
+    Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
+    Value noneCst = rewriter.create<ConstantNoneOp>(loc);
+    auto intMaskType = flattenedInputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), si64Type);
+    Value intMask = rewriter.create<AtenToDtypeOp>(
+        loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
+
+    // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
+    auto cumulativeSumType =
+        dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
+            flattenedInputType.getOptionalSizes(), si64Type));
+    Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
+                                                        intMask, c(0), noneCst);
+    Value one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value subtracted = rewriter.create<AtenSubScalarOp>(
+        loc, cumulativeSumType, cumulativeSum, one, /*alpha=*/one);
+
+    // destination_indices = torch.clamp(destination_indices, min=0)
+    Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
+                                                    subtracted, c(0));
+
+    // iota = torch.tensor(range(len(t))) * nonzero_mask.int()
+    Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
+        loc, cumulativeSumType, c(0),
+        rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
+        one, noneCst, noneCst, noneCst, noneCst);
+    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
+                                                        rangeTensor, intMask);
+
+    // scatter_self = torch.zeros_like(t, dtype=torch.int64)
+    // AtenFullLike doesn't support index type so we have to use si64.
+    auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype(
+        cumulativeSumType.getOptionalSizes(), si64Type);
+    Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
+        loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
+        noneCst, noneCst);
+
+    // compacted = scatter_self.scatter_(
+    //     dim=0,
+    //     index=destination_indices,
+    //     src=iota, reduce='add')
+    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
+    Value constAxis = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+    Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
+        loc, cumulativeSumType, zerosTensor, /*axis=*/constAxis,
+        /*dims=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
+
+    // result_flat = compacted[:torch.sum(nonzero_mask)]
+    auto scalarType = ValueTensorType::get(rewriter.getContext(),
+                                           ArrayRef<int64_t>{}, si64Type);
+    Value sumMask =
+        rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
+    Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);
+
+    auto slicedResultType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
+    Value slicedResult =
+        rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
+                                           /*self=*/scatteredTensor,
+                                           /*dim=*/c(0),
+                                           /*start=*/c(0),
+                                           /*end=*/numNonzero,
+                                           /*step=*/one);
+
+    // strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
+    Value flippedShape = rewriter.create<AtenFlipOp>(
+        loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
+    Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
+        loc, shapeType, flippedShape, c(0), noneCst);
+    Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
+        loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
+
+    // strides = torch.cat([strides[1:], torch.tensor([1],
+    //                                                device=t.device)])
+    auto oneTensorType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
+    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+        loc, oneTensorType, c(1), si64Dtype, noneCst, noneCst, noneCst);
+
+    auto slicedStrideType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
+        si64Type);
+    Value strideSliceStart = c(1);
+    Value strideSliceEnd = c(inputRank);
+    Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
+        loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ c(0),
+        /*start=*/strideSliceStart, /*end=*/strideSliceEnd, /*step=*/c(1));
+
+    auto tensorListElementType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
+    Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(tensorListElementType),
+        SmallVector<Value>{slicedStrides, oneTensor});
+    Value strides =
+        rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList, c(0));
+
+    // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
+    //                 inputShapeTensor
+    auto unsqueezedResultType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
+    Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedResultType, slicedResult, c(1));
+
+    auto unsqueezedStridesType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
+    Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedStridesType, strides, c(0));
+
+    auto dividedBroadcastType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
+        si64Type);
+    Value divided = rewriter.create<AtenFloorDivideOp>(
+        loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
+
+    auto resultType = cast<BaseTensorType>(op.getType());
+    Value modded = rewriter.create<AtenRemainderTensorOp>(
+        loc, resultType, divided, inputShapeTensor);
+
+    rewriter.replaceOp(op, modded);
+    return success();
+  }
+};
+
 // Decompose aten.addmm into aten.mm and aten.add.Tensor op.
 namespace {
 class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
```

@@ -10573,6 +10759,7 @@ class DecomposeComplexOpsPass

```diff
     addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
```
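
To make the algorithm easier to follow outside the rewriter boilerplate, here is an eager-PyTorch sketch of the same decomposition, assembled from the step comments above. It is a reading aid, not part of the commit: the name `nonzero_1d_reference` is invented, and it uses `Tensor.scatter_reduce` on a zero-initialized buffer in place of the `AtenScatterReduceTwoOp` call (equivalent for a "sum" reduction).

```python
import torch

def nonzero_1d_reference(t: torch.Tensor) -> torch.Tensor:
    # original_shape = t.shape, kept as a tensor for the stride math below.
    original_shape = torch.tensor(t.shape)
    t = t.flatten()
    # 1 where the element is nonzero, 0 elsewhere.
    nonzero_mask = (t != 0).long()
    # Slot each surviving element should occupy in the compacted output.
    destination_indices = torch.clamp(torch.cumsum(nonzero_mask, 0) - 1, min=0)
    # Element positions, zeroed out for elements that are zero.
    iota = torch.arange(len(t)) * nonzero_mask
    # Scatter-add the surviving positions into the front of a zero buffer.
    compacted = torch.zeros_like(t, dtype=torch.int64).scatter_reduce(
        0, destination_indices, iota, reduce="sum")
    result_flat = compacted[: int(nonzero_mask.sum())]
    # Row-major strides: cumprod of the reversed shape, shifted left by one.
    strides = torch.cumprod(original_shape.flip(0), 0).flip(0)
    strides = torch.cat([strides[1:], torch.tensor([1])])
    # Unflatten: per-dimension index = (flat_index // stride) % dim_size.
    return (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % original_shape

print(nonzero_1d_reference(torch.tensor([0, 0, 5, 0, 0, 0])))  # tensor([[2]])
```

For the 1-D case `strides` collapses to `tensor([1])`, so the final floor-divide and modulo leave the flat indices unchanged; the stride handling is what would let the same skeleton generalize beyond rank 1.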

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 23 additions & 0 deletions
@@ -6255,3 +6255,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):

```diff
     module.forward(
         tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
     )
+
+
+# ==============================================================================
+
+
+class AtenNonzero1DModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.bool, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.nonzero(x)
+
+
+@register_test_case(module_factory=lambda: AtenNonzero1DModule())
+def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 0, 5, 0, 0, 0], dtype=torch.int))
```
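
For reference, the e2e harness checks the compiled module against eager PyTorch, which returns a (num_nonzero, rank) index tensor. For this input, the single nonzero element sits at flat index 2:

```python
>>> import torch
>>> torch.nonzero(torch.tensor([0, 0, 5, 0, 0, 0], dtype=torch.int))
tensor([[2]])
```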
