diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index be16c4fb61..0ca6e0b452 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -724,6 +724,44 @@ def local_useless_subtensor(fgraph, node): return [node.inputs[0]] +@register_canonicalize +@node_rewriter([Subtensor]) +def local_convert_negative_indices(fgraph, node): + """Convert negative indices in `Subtensor` with static length to positive indices.""" + x, *raw_idxs = node.inputs + idxs = indices_from_subtensor(raw_idxs, node.op.idx_list) + + new_idxs = None + for i, (dim_length, idx) in enumerate(zip(x.type.shape, idxs)): + if ( + dim_length is None + or isinstance(idx, slice) + or not isinstance(idx, Constant) + ): + continue + + val = idx.data + if val >= 0: + continue + + new_val = val + dim_length + if new_val < 0: + # This is an invalid index, keep original to not confuse the user + return None + + if new_idxs is None: + new_idxs = list(idxs) + new_idxs[i] = new_val + + if new_idxs is None: + # No negative indices to convert + return None + + new_subtensor = x[tuple(new_idxs)] + copy_stack_trace(node.outputs, new_subtensor) + return [new_subtensor] + + @register_canonicalize @register_specialize @node_rewriter([AdvancedSubtensor1]) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 0be51819d4..4cb2b0f4cd 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -1992,3 +1992,20 @@ def test_extract_diag_of_diagonal_set_subtensor(): expected_outs.append(outs[-1]) assert equal_computations(rewritten_outs, expected_outs) + + +def test_local_convert_negative_indices(): + x = pt.tensor("x", shape=(None, 3, 1)) + + # Dim length is unknown rewrite can't be applied + rewritten_out = rewrite_graph(x[-2]) + assert equal_computations([rewritten_out], [x[-2]]) + + # Rewrite applies + rewritten_out = rewrite_graph(x[:, -2]) + assert equal_computations([rewritten_out], [x[:, 1]]) + + # Rewrite doesn't apply because index is invalid + # TODO: If Subtensor decides to raise on make_node, this test can be removed + rewritten_out = rewrite_graph(x[:, :, -2]) + assert equal_computations([rewritten_out], [x[:, :, -2]]) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 933d1a1577..ccfa033859 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -202,7 +202,7 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn): out = original_fn(x) expected_opt_out = expected_fn(x) - opt_out = rewrite_graph(out) + opt_out = rewrite_graph(out, exclude=("local_convert_negative_indices",)) assert equal_computations([opt_out], [expected_opt_out]), debugprint( [expected_opt_out, opt_out], print_type=True )