File tree Expand file tree Collapse file tree 2 files changed +55
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +55
-0
lines changed Original file line number Diff line number Diff line change @@ -724,6 +724,44 @@ def local_useless_subtensor(fgraph, node):
724
724
return [node .inputs [0 ]]
725
725
726
726
727
+ @register_canonicalize
728
+ @node_rewriter ([Subtensor ])
729
+ def local_convert_negative_indices (fgraph , node ):
730
+ """Convert negative indices in `Subtensor` with static length to positive indices."""
731
+ x , * raw_idxs = node .inputs
732
+ idxs = indices_from_subtensor (raw_idxs , node .op .idx_list )
733
+
734
+ new_idxs = None
735
+ for i , (dim_length , idx ) in enumerate (zip (x .type .shape , idxs )):
736
+ if (
737
+ dim_length is None
738
+ or isinstance (idx , slice )
739
+ or not isinstance (idx , Constant )
740
+ ):
741
+ continue
742
+
743
+ val = idx .data
744
+ if val >= 0 :
745
+ continue
746
+
747
+ new_val = val + dim_length
748
+ if new_val < 0 :
749
+ # This is an invalid index, keep original to not confuse the user
750
+ return None
751
+
752
+ if new_idxs is None :
753
+ new_idxs = list (idxs )
754
+ new_idxs [i ] = new_val
755
+
756
+ if new_idxs is None :
757
+ # No negative indices to convert
758
+ return None
759
+
760
+ new_subtensor = x [tuple (new_idxs )]
761
+ copy_stack_trace (node .outputs , new_subtensor )
762
+ return [new_subtensor ]
763
+
764
+
727
765
@register_canonicalize
728
766
@register_specialize
729
767
@node_rewriter ([AdvancedSubtensor1 ])
Original file line number Diff line number Diff line change @@ -1992,3 +1992,20 @@ def test_extract_diag_of_diagonal_set_subtensor():
1992
1992
expected_outs .append (outs [- 1 ])
1993
1993
1994
1994
assert equal_computations (rewritten_outs , expected_outs )
1995
+
1996
+
1997
+ def test_local_convert_negative_indices ():
1998
+ x = pt .tensor ("x" , shape = (None , 3 , 1 ))
1999
+
2000
+ # Dim length is unknown rewrite can't be applied
2001
+ rewritten_out = rewrite_graph (x [- 2 ])
2002
+ assert equal_computations ([rewritten_out ], [x [- 2 ]])
2003
+
2004
+ # Rewrite applies
2005
+ rewritten_out = rewrite_graph (x [:, - 2 ])
2006
+ assert equal_computations ([rewritten_out ], [x [:, 1 ]])
2007
+
2008
+ # Rewrite doesn't apply because index is invalid
2009
+ # TODO: If Subtensor decides to raise on make_node, this test can be removed
2010
+ rewritten_out = rewrite_graph (x [:, :, - 2 ])
2011
+ assert equal_computations ([rewritten_out ], [x [:, :, - 2 ]])
You can’t perform that action at this time.
0 commit comments