Skip to content

Commit

Permalink
Fix AtenArangeStartStepOp dynamic end support
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Nov 20, 2024
1 parent 35e20e0 commit f7cd0fc
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5604,11 +5604,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
subtracted, c(0));

// iota = torch.tensor(range(len(t))) * nonzero_mask.int()
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value end =
rewriter.create<AtenSizeIntOp>(loc, flattenedInput, /*dim=*/constZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, cumulativeSumType, c(0),
rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
one, noneCst, noneCst, noneCst, noneCst);
loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst,
noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
rangeTensor, intMask);

Expand Down

0 comments on commit f7cd0fc

Please sign in to comment.