From f7cd0fc6158b742ed238327ec50c1e37dcc730c7 Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Tue, 19 Nov 2024 19:41:29 -0800 Subject: [PATCH] Fix AtenArangeStartStepOp dynamic end support --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index de5faa9794a8..131fa4f046d0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5604,11 +5604,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { subtracted, c(0)); // iota = torch.tensor(range(len(t))) * nonzero_mask.int() + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value end = + rewriter.create(loc, flattenedInput, /*dim=*/constZero); Value rangeTensor = rewriter.create( - loc, cumulativeSumType, c(0), - rewriter.create( - loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])), - one, noneCst, noneCst, noneCst, noneCst); + loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst, + noneCst); Value multiplied = rewriter.create(loc, cumulativeSumType, rangeTensor, intMask);