Skip to content

Commit 5516080

Browse files
committed
Canonicalize subtensor negative integer indices
1 parent b4522d2 commit 5516080

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,44 @@ def local_useless_subtensor(fgraph, node):
724724
return [node.inputs[0]]
725725

726726

727+
@register_canonicalize
728+
@node_rewriter([Subtensor])
729+
def local_convert_negative_indices(fgraph, node):
730+
"""Convert negative indices in `Subtensor` with static length to positive indices."""
731+
x, *raw_idxs = node.inputs
732+
idxs = indices_from_subtensor(raw_idxs, node.op.idx_list)
733+
734+
new_idxs = None
735+
for i, (dim_length, idx) in enumerate(zip(x.type.shape, idxs)):
736+
if (
737+
dim_length is None
738+
or isinstance(idx, slice)
739+
or not isinstance(idx, Constant)
740+
):
741+
continue
742+
743+
val = idx.data
744+
if val >= 0:
745+
continue
746+
747+
new_val = val + dim_length
748+
if new_val < 0:
749+
# This is an invalid index, keep original to not confuse the user
750+
return None
751+
752+
if new_idxs is None:
753+
new_idxs = list(idxs)
754+
new_idxs[i] = new_val
755+
756+
if new_idxs is None:
757+
# No negative indices to convert
758+
return None
759+
760+
new_subtensor = x[tuple(new_idxs)]
761+
copy_stack_trace(node.outputs, new_subtensor)
762+
return [new_subtensor]
763+
764+
727765
@register_canonicalize
728766
@register_specialize
729767
@node_rewriter([AdvancedSubtensor1])

tests/tensor/rewriting/test_subtensor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,3 +1992,20 @@ def test_extract_diag_of_diagonal_set_subtensor():
19921992
expected_outs.append(outs[-1])
19931993

19941994
assert equal_computations(rewritten_outs, expected_outs)
1995+
1996+
1997+
def test_local_convert_negative_indices():
1998+
x = pt.tensor("x", shape=(None, 3, 1))
1999+
2000+
# Dim length is unknown rewrite can't be applied
2001+
rewritten_out = rewrite_graph(x[-2])
2002+
assert equal_computations([rewritten_out], [x[-2]])
2003+
2004+
# Rewrite applies
2005+
rewritten_out = rewrite_graph(x[:, -2])
2006+
assert equal_computations([rewritten_out], [x[:, 1]])
2007+
2008+
# Rewrite doesn't apply because index is invalid
2009+
# TODO: If Subtensor decides to raise on make_node, this test can be removed
2010+
rewritten_out = rewrite_graph(x[:, :, -2])
2011+
assert equal_computations([rewritten_out], [x[:, :, -2]])

0 commit comments

Comments
 (0)