
Canonicalize subtensor negative integer indices #1541


Merged
merged 1 commit on Jul 15, 2025
38 changes: 38 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
@@ -724,6 +724,44 @@ def local_useless_subtensor(fgraph, node):
    return [node.inputs[0]]


@register_canonicalize
@node_rewriter([Subtensor])
def local_convert_negative_indices(fgraph, node):
    """Convert negative indices in `Subtensor` with static length to positive indices."""
    x, *raw_idxs = node.inputs
    idxs = indices_from_subtensor(raw_idxs, node.op.idx_list)

    new_idxs = None
    for i, (dim_length, idx) in enumerate(zip(x.type.shape, idxs)):
        if (
            dim_length is None
            or isinstance(idx, slice)
Member: We could also canonicalize negative slices if they're static. Is that out of scope here?

Member Author: Yes. Slices touch on the whole subtensor_merge monster and I don't want to touch it for now.
            or not isinstance(idx, Constant)
        ):
            continue

        val = idx.data
        if val >= 0:
            continue

        new_val = val + dim_length
        if new_val < 0:
            # This is an invalid index, keep original to not confuse the user
            return None

        if new_idxs is None:
            new_idxs = list(idxs)
        new_idxs[i] = new_val

    if new_idxs is None:
        # No negative indices to convert
        return None

    new_subtensor = x[tuple(new_idxs)]
    copy_stack_trace(node.outputs, new_subtensor)
    return [new_subtensor]


@register_canonicalize
@register_specialize
@node_rewriter([AdvancedSubtensor1])
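Following up on the review exchange above about canonicalizing static negative slices: a minimal sketch, assuming a constant slice and a statically known dim_length, of what that normalization could look like. This is illustrative only and not part of this PR; Python's built-in slice.indices does the clamping.

def normalize_constant_slice(s: slice, dim_length: int) -> slice:
    # slice.indices resolves negative/None bounds against dim_length,
    # exactly as Python indexing would for a sequence of that length.
    start, stop, step = s.indices(dim_length)
    # NOTE: negative steps need extra care; .indices() can return stop=-1,
    # which does not round-trip into an equivalent slice object.
    return slice(start, stop, step)

# For dim_length = 5:
#   normalize_constant_slice(slice(-3, None), 5)  -> slice(2, 5, 1)
#   normalize_constant_slice(slice(None, -1), 5)  -> slice(0, 4, 1)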
17 changes: 17 additions & 0 deletions tests/tensor/rewriting/test_subtensor.py
@@ -1992,3 +1992,20 @@ def test_extract_diag_of_diagonal_set_subtensor():
        expected_outs.append(outs[-1])

    assert equal_computations(rewritten_outs, expected_outs)


def test_local_convert_negative_indices():
    x = pt.tensor("x", shape=(None, 3, 1))

    # Dim length is unknown, so the rewrite can't be applied
    rewritten_out = rewrite_graph(x[-2])
    assert equal_computations([rewritten_out], [x[-2]])

    # Rewrite applies
    rewritten_out = rewrite_graph(x[:, -2])
    assert equal_computations([rewritten_out], [x[:, 1]])

    # Rewrite doesn't apply because the index is invalid
    # TODO: If Subtensor decides to raise in make_node, this test can be removed
    rewritten_out = rewrite_graph(x[:, :, -2])
    assert equal_computations([rewritten_out], [x[:, :, -2]])
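Since the rewrite is registered with register_canonicalize, the same conversion should also appear when a graph is compiled normally, not just when calling rewrite_graph directly. A small sketch (not part of the PR) using the public pytensor.function and debugprint API:

import pytensor
import pytensor.tensor as pt
from pytensor.printing import debugprint

x = pt.tensor("x", shape=(None, 3, 1))
fn = pytensor.function([x], x[:, -2])  # dim 1 has static length 3
debugprint(fn)  # the compiled Subtensor should index with 1 rather than -2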
2 changes: 1 addition & 1 deletion tests/tensor/rewriting/test_subtensor_lift.py
@@ -202,7 +202,7 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):

    out = original_fn(x)
    expected_opt_out = expected_fn(x)
-   opt_out = rewrite_graph(out)
+   opt_out = rewrite_graph(out, exclude=("local_convert_negative_indices",))
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )