diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index de5faa9794a83..131fa4f046d0b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5604,11 +5604,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { subtracted, c(0)); // iota = torch.tensor(range(len(t))) * nonzero_mask.int() + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value end = + rewriter.create(loc, flattenedInput, /*dim=*/constZero); Value rangeTensor = rewriter.create( - loc, cumulativeSumType, c(0), - rewriter.create( - loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])), - one, noneCst, noneCst, noneCst, noneCst); + loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst, + noneCst); Value multiplied = rewriter.create(loc, cumulativeSumType, rangeTensor, intMask);