Skip to content

Commit 4829455

Browse files
ricardoV94 authored and jessegrabowski committed
Fix bug in local_blockwise_advanced_inc_subtensor
1 parent 709f745 commit 4829455

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
alloc,
2626
cast,
2727
concatenate,
28+
expand_dims,
2829
get_scalar_constant_value,
2930
get_underlying_scalar_constant_value,
3031
register_infer_shape,
@@ -1576,7 +1577,15 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
15761577
x = alloc(x, *batch_shape, *core_shape)
15771578

15781579
new_idxs = [slice(None)] * batch_ndim + new_idxs
1579-
symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
1580+
x_view = x[tuple(new_idxs)]
1581+
1582+
# We need to introduce any implicit expand_dims on core dimension of y
1583+
y_core_ndim = y.type.ndim - batch_ndim
1584+
if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
1585+
missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
1586+
y = expand_dims(y, missing_axes)
1587+
1588+
symbolic_idxs = x_view.owner.inputs[1:]
15801589
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
15811590
copy_stack_trace(node.outputs, new_out)
15821591
return new_out

tests/tensor/rewriting/test_subtensor.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,10 +1788,24 @@ def test_local_uint_constant_indices():
17881788
assert new_index.type.dtype == "uint8"
17891789

17901790

1791+
@pytest.mark.parametrize("core_y_implicitly_batched", (False, True))
17911792
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
1792-
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
1793+
def test_local_blockwise_advanced_inc_subtensor(
1794+
set_instead_of_inc, core_y_implicitly_batched
1795+
):
1796+
rng = np.random.default_rng([1764, set_instead_of_inc, core_y_implicitly_batched])
1797+
1798+
def np_inplace_f(x, idx, y):
1799+
if core_y_implicitly_batched:
1800+
y = y[..., None]
1801+
if set_instead_of_inc:
1802+
x[idx] = y
1803+
else:
1804+
x[idx] += y
1805+
1806+
core_y_shape = () if core_y_implicitly_batched else (3,)
17931807
core_x = tensor("x", shape=(6,))
1794-
core_y = tensor("y", shape=(3,))
1808+
core_y = tensor("y", shape=core_y_shape, dtype=int)
17951809
core_idxs = [0, 2, 4]
17961810
if set_instead_of_inc:
17971811
core_graph = set_subtensor(core_x[core_idxs], core_y)
@@ -1800,7 +1814,7 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
18001814

18011815
# Only x is batched
18021816
x = tensor("x", shape=(5, 2, 6))
1803-
y = tensor("y", shape=(3,))
1817+
y = tensor("y", shape=core_y_shape, dtype=int)
18041818
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
18051819
assert isinstance(out.owner.op, Blockwise)
18061820

@@ -1810,17 +1824,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
18101824
)
18111825

18121826
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
1813-
test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
1827+
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
18141828
expected_out = test_x.copy()
1815-
if set_instead_of_inc:
1816-
expected_out[:, :, core_idxs] = test_y
1817-
else:
1818-
expected_out[:, :, core_idxs] += test_y
1829+
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
18191830
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
18201831

18211832
# Only y is batched
18221833
x = tensor("y", shape=(6,))
1823-
y = tensor("y", shape=(2, 3))
1834+
y = tensor("y", shape=(2, *core_y_shape), dtype=int)
18241835
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
18251836
assert isinstance(out.owner.op, Blockwise)
18261837

@@ -1830,17 +1841,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
18301841
)
18311842

18321843
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
1833-
test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
1844+
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
18341845
expected_out = np.ones((2, *x.type.shape))
1835-
if set_instead_of_inc:
1836-
expected_out[:, core_idxs] = test_y
1837-
else:
1838-
expected_out[:, core_idxs] += test_y
1846+
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
18391847
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
18401848

18411849
# Both x and y are batched, and do not need to be broadcasted
18421850
x = tensor("y", shape=(2, 6))
1843-
y = tensor("y", shape=(2, 3))
1851+
y = tensor("y", shape=(2, *core_y_shape), dtype=int)
18441852
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
18451853
assert isinstance(out.owner.op, Blockwise)
18461854

@@ -1850,17 +1858,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
18501858
)
18511859

18521860
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
1853-
test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
1861+
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
18541862
expected_out = test_x.copy()
1855-
if set_instead_of_inc:
1856-
expected_out[:, core_idxs] = test_y
1857-
else:
1858-
expected_out[:, core_idxs] += test_y
1863+
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
18591864
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
18601865

18611866
# Both x and y are batched, but must be broadcasted
18621867
x = tensor("y", shape=(5, 1, 6))
1863-
y = tensor("y", shape=(1, 2, 3))
1868+
y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
18641869
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
18651870
assert isinstance(out.owner.op, Blockwise)
18661871

@@ -1870,16 +1875,13 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
18701875
)
18711876

18721877
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
1873-
test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
1878+
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
18741879
final_shape = (
1875-
*np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
1880+
*np.broadcast_shapes(x.type.shape[:2], y.type.shape[:2]),
18761881
x.type.shape[-1],
18771882
)
18781883
expected_out = np.broadcast_to(test_x, final_shape).copy()
1879-
if set_instead_of_inc:
1880-
expected_out[:, :, core_idxs] = test_y
1881-
else:
1882-
expected_out[:, :, core_idxs] += test_y
1884+
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
18831885
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
18841886

18851887

0 commit comments

Comments
 (0)