@@ -1788,10 +1788,24 @@ def test_local_uint_constant_indices():
1788
1788
assert new_index .type .dtype == "uint8"
1789
1789
1790
1790
1791
+ @pytest .mark .parametrize ("core_y_implicitly_batched" , (False , True ))
1791
1792
@pytest .mark .parametrize ("set_instead_of_inc" , (True , False ))
1792
- def test_local_blockwise_advanced_inc_subtensor (set_instead_of_inc ):
1793
+ def test_local_blockwise_advanced_inc_subtensor (
1794
+ set_instead_of_inc , core_y_implicitly_batched
1795
+ ):
1796
+ rng = np .random .default_rng ([1764 , set_instead_of_inc , core_y_implicitly_batched ])
1797
+
1798
+ def np_inplace_f (x , idx , y ):
1799
+ if core_y_implicitly_batched :
1800
+ y = y [..., None ]
1801
+ if set_instead_of_inc :
1802
+ x [idx ] = y
1803
+ else :
1804
+ x [idx ] += y
1805
+
1806
+ core_y_shape = () if core_y_implicitly_batched else (3 ,)
1793
1807
core_x = tensor ("x" , shape = (6 ,))
1794
- core_y = tensor ("y" , shape = ( 3 ,) )
1808
+ core_y = tensor ("y" , shape = core_y_shape , dtype = int )
1795
1809
core_idxs = [0 , 2 , 4 ]
1796
1810
if set_instead_of_inc :
1797
1811
core_graph = set_subtensor (core_x [core_idxs ], core_y )
@@ -1800,7 +1814,7 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1800
1814
1801
1815
# Only x is batched
1802
1816
x = tensor ("x" , shape = (5 , 2 , 6 ))
1803
- y = tensor ("y" , shape = ( 3 ,) )
1817
+ y = tensor ("y" , shape = core_y_shape , dtype = int )
1804
1818
out = vectorize_graph (core_graph , replace = {core_x : x , core_y : y })
1805
1819
assert isinstance (out .owner .op , Blockwise )
1806
1820
@@ -1810,17 +1824,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1810
1824
)
1811
1825
1812
1826
test_x = np .ones (x .type .shape , dtype = x .type .dtype )
1813
- test_y = np . array ([ 5 , 6 , 7 ]). astype ( dtype = core_y .type .dtype )
1827
+ test_y = rng . integers ( 1 , 10 , size = y . type . shape , dtype = y .type .dtype )
1814
1828
expected_out = test_x .copy ()
1815
- if set_instead_of_inc :
1816
- expected_out [:, :, core_idxs ] = test_y
1817
- else :
1818
- expected_out [:, :, core_idxs ] += test_y
1829
+ np_inplace_f (expected_out , np .s_ [:, :, core_idxs ], test_y )
1819
1830
np .testing .assert_allclose (fn (test_x , test_y ), expected_out )
1820
1831
1821
1832
# Only y is batched
1822
1833
x = tensor ("y" , shape = (6 ,))
1823
- y = tensor ("y" , shape = (2 , 3 ) )
1834
+ y = tensor ("y" , shape = (2 , * core_y_shape ), dtype = int )
1824
1835
out = vectorize_graph (core_graph , replace = {core_x : x , core_y : y })
1825
1836
assert isinstance (out .owner .op , Blockwise )
1826
1837
@@ -1830,17 +1841,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1830
1841
)
1831
1842
1832
1843
test_x = np .ones (x .type .shape , dtype = x .type .dtype )
1833
- test_y = np . array ([[ 3 , 3 , 3 ], [ 5 , 6 , 7 ]]). astype ( dtype = core_y .type .dtype )
1844
+ test_y = rng . integers ( 1 , 10 , size = y . type . shape , dtype = y .type .dtype )
1834
1845
expected_out = np .ones ((2 , * x .type .shape ))
1835
- if set_instead_of_inc :
1836
- expected_out [:, core_idxs ] = test_y
1837
- else :
1838
- expected_out [:, core_idxs ] += test_y
1846
+ np_inplace_f (expected_out , np .s_ [:, core_idxs ], test_y )
1839
1847
np .testing .assert_allclose (fn (test_x , test_y ), expected_out )
1840
1848
1841
1849
# Both x and y are batched, and do not need to be broadcasted
1842
1850
x = tensor ("y" , shape = (2 , 6 ))
1843
- y = tensor ("y" , shape = (2 , 3 ) )
1851
+ y = tensor ("y" , shape = (2 , * core_y_shape ), dtype = int )
1844
1852
out = vectorize_graph (core_graph , replace = {core_x : x , core_y : y })
1845
1853
assert isinstance (out .owner .op , Blockwise )
1846
1854
@@ -1850,17 +1858,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1850
1858
)
1851
1859
1852
1860
test_x = np .ones (x .type .shape , dtype = x .type .dtype )
1853
- test_y = np . array ([[ 5 , 6 , 7 ], [ 3 , 3 , 3 ]]). astype ( dtype = core_y .type .dtype )
1861
+ test_y = rng . integers ( 1 , 10 , size = y . type . shape , dtype = y .type .dtype )
1854
1862
expected_out = test_x .copy ()
1855
- if set_instead_of_inc :
1856
- expected_out [:, core_idxs ] = test_y
1857
- else :
1858
- expected_out [:, core_idxs ] += test_y
1863
+ np_inplace_f (expected_out , np .s_ [:, core_idxs ], test_y )
1859
1864
np .testing .assert_allclose (fn (test_x , test_y ), expected_out )
1860
1865
1861
1866
# Both x and y are batched, but must be broadcasted
1862
1867
x = tensor ("y" , shape = (5 , 1 , 6 ))
1863
- y = tensor ("y" , shape = (1 , 2 , 3 ) )
1868
+ y = tensor ("y" , shape = (1 , 2 , * core_y_shape ), dtype = int )
1864
1869
out = vectorize_graph (core_graph , replace = {core_x : x , core_y : y })
1865
1870
assert isinstance (out .owner .op , Blockwise )
1866
1871
@@ -1870,16 +1875,13 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1870
1875
)
1871
1876
1872
1877
test_x = np .ones (x .type .shape , dtype = x .type .dtype )
1873
- test_y = np . array ([[[ 5 , 6 , 7 ], [ 3 , 3 , 3 ]]]). astype ( dtype = core_y .type .dtype )
1878
+ test_y = rng . integers ( 1 , 10 , size = y . type . shape , dtype = y .type .dtype )
1874
1879
final_shape = (
1875
- * np .broadcast_shapes (x .type .shape [:- 1 ], y .type .shape [:- 1 ]),
1880
+ * np .broadcast_shapes (x .type .shape [:2 ], y .type .shape [:2 ]),
1876
1881
x .type .shape [- 1 ],
1877
1882
)
1878
1883
expected_out = np .broadcast_to (test_x , final_shape ).copy ()
1879
- if set_instead_of_inc :
1880
- expected_out [:, :, core_idxs ] = test_y
1881
- else :
1882
- expected_out [:, :, core_idxs ] += test_y
1884
+ np_inplace_f (expected_out , np .s_ [:, :, core_idxs ], test_y )
1883
1885
np .testing .assert_allclose (fn (test_x , test_y ), expected_out )
1884
1886
1885
1887
0 commit comments