Commit 641b79c

Define all batched dot operations as matmul

A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products.

1 parent de83717

File tree

6 files changed: +213 −81 lines
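The equivalence this commit relies on is easy to check outside of PyTensor; a minimal NumPy sketch of an "unpaired" batched matrix-vector product expressed as a matmul (illustrative only, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3, 4))  # batched matrices
v = rng.normal(size=(4,))       # single, unbatched vector

reference = np.einsum("bij,j->bi", A, v)            # batched matrix-vector product
as_matmul = np.matmul(A, v[..., None]).squeeze(-1)  # same product written as a matmul
assert np.allclose(reference, as_matmul)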

pytensor/tensor/math.py

Lines changed: 16 additions & 32 deletions
@@ -3916,23 +3916,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-# Predefine all batched variations of Dot
-_inner_prod = Blockwise(
-    _dot,
-    signature="(n),(n)->()",
-)
-
-_matrix_vec_prod = Blockwise(
-    _dot,
-    signature="(m,k),(k)->(m)",
-)
-
-_vec_matrix_prod = Blockwise(
-    _dot,
-    signature="(k),(k,n)->(n)",
-)
-
-_matrix_matrix_matmul = Blockwise(
+_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
     gufunc_spec=("numpy.matmul", 2, 1),
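As a reminder of what the (m,k),(k,n)->(m,n) signature with gufunc_spec numpy.matmul implies, leading (batch) dimensions broadcast like in any other gufunc; a small NumPy illustration (not part of the diff):

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(7, 1, 2, 3))  # batch shape (7, 1)
y = rng.normal(size=(5, 3, 4))     # batch shape (5,)

out = np.matmul(x, y)  # batch shapes broadcast to (7, 5); core product is (2, 3) @ (3, 4)
assert out.shape == (7, 5, 2, 4)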
@@ -3988,11 +3972,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     if x1.type.ndim == 1 and x2.type.ndim == 1:
         out = _dot(x1, x2)
     elif x1.type.ndim == 1:
-        out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
+        out = vecmat(x1, x2)
     elif x2.type.ndim == 1:
-        out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
+        out = matvec(x1, x2)
     else:
-        out = _matrix_matrix_matmul(x1, x2)
+        out = _matmul(x1, x2)
 
     if dtype is not None:
         out = out.astype(dtype)
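The two vector branches above keep the same expand-then-squeeze identities that the removed lines used; checked here in NumPy (illustrative only):

import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(3, 4))
v3 = rng.normal(size=(3,))
v4 = rng.normal(size=(4,))

# 1-D x1: prepend a row axis, matmul, then drop it -> a vector-matrix product (vecmat)
assert np.allclose(np.matmul(v3[None, :], A).squeeze(-2), v3 @ A)
# 1-D x2: append a column axis, matmul, then drop it -> a matrix-vector product (matvec)
assert np.allclose(np.matmul(A, v4[:, None]).squeeze(-1), A @ v4)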
@@ -4042,7 +4026,7 @@ def vecdot(
     >>> z_batch = pt.vecdot(x_batch, y_batch)  # shape (3,)
     >>> # Equivalent to numpy.vecdot(x_batch, y_batch)
     """
-    out = _inner_prod(x1, x2)
+    out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))
 
     if dtype is not None:
         out = out.astype(dtype)
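The new vecdot lowering treats the two vectors as a row and a column matrix and squeezes both unit axes away; a NumPy check of that identity (illustrative only):

import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(5, 3))

reference = np.einsum("...i,...i->...", x, y)  # batched inner product
as_matmul = np.matmul(x[..., None, :], y[..., :, None]).squeeze((-2, -1))
assert np.allclose(reference, as_matmul)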
@@ -4091,7 +4075,7 @@ def matvec(
     >>> result = pt.matvec(batched_A, batched_v)  # shape (2, 3)
     >>> # Equivalent to numpy.matvec(batched_A, batched_v)
     """
-    out = _matrix_vec_prod(x1, x2)
+    out = matmul(x1, x2[..., None]).squeeze(-1)
 
     if dtype is not None:
         out = out.astype(dtype)
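Likewise, matvec now appends a length-1 column axis to the vector, runs the matmul, and squeezes it back off; NumPy check (illustrative only):

import numpy as np

rng = np.random.default_rng(4)
A = rng.normal(size=(2, 3, 4))
v = rng.normal(size=(2, 4))

reference = np.einsum("...ij,...j->...i", A, v)
as_matmul = np.matmul(A, v[..., None]).squeeze(-1)
assert np.allclose(reference, as_matmul)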
@@ -4129,18 +4113,18 @@ def vecmat(
     --------
     >>> import pytensor.tensor as pt
     >>> # Vector-matrix product
-    >>> v = pt.vector("v", shape=(3,))  # shape (3,)
-    >>> A = pt.matrix("A", shape=(3, 4))  # shape (3, 4)
+    >>> v = pt.vector("v", shape=(3,))
+    >>> A = pt.matrix("A", shape=(3, 4))
     >>> result = pt.vecmat(v, A)  # shape (4,)
     >>> # Equivalent to numpy.vecmat(v, A)
     >>>
     >>> # Batched vector-matrix product
-    >>> batched_v = pt.matrix("v", shape=(2, 3))  # shape (2, 3)
-    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))  # shape (2, 3, 4)
+    >>> batched_v = pt.matrix("v", shape=(2, 3))
+    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
     >>> result = pt.vecmat(batched_v, batched_A)  # shape (2, 4)
     >>> # Equivalent to numpy.vecmat(batched_v, batched_A)
     """
-    out = _vec_matrix_prod(x1, x2)
+    out = matmul(x2.mT, x1[..., None]).squeeze(-1)
 
     if dtype is not None:
         out = out.astype(dtype)
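vecmat is expressed through the transposed matrix acting on a column vector; for real-valued inputs (no conjugation involved) the identity can be checked in NumPy, with np.swapaxes standing in for .mT (illustrative only):

import numpy as np

rng = np.random.default_rng(5)
v = rng.normal(size=(2, 3))     # batched vectors
A = rng.normal(size=(2, 3, 4))  # batched matrices

reference = np.einsum("...i,...ij->...j", v, A)
as_matmul = np.matmul(np.swapaxes(A, -1, -2), v[..., None]).squeeze(-1)
assert np.allclose(reference, as_matmul)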
@@ -4155,18 +4139,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
     old_y_ndim = old_y.type.ndim
     match (old_x_ndim, old_y_ndim):
         case (1, 1):
-            batch_op = _inner_prod
+            batch_fn = vecdot
         case (2, 1):
-            batch_op = _matrix_vec_prod
+            batch_fn = matvec
         case (1, 2):
-            batch_op = _vec_matrix_prod
+            batch_fn = vecmat
         case (2, 2):
-            batch_op = _matrix_matrix_matmul
+            batch_fn = matmul
         case _:
             raise ValueError(
                 f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
             )
-    return batch_op(batched_x, batched_y).owner
+    return batch_fn(batched_x, batched_y).owner
 
 
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
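The match statement above picks, per core input ndim, the helper whose batched behaviour matches looping the original Dot over the batch dimension; the intended semantics of each case, checked with NumPy (illustrative only, not the PyTensor implementation):

import numpy as np

rng = np.random.default_rng(6)
x = rng.normal(size=(5, 3))     # batched 1-D inputs
y = rng.normal(size=(5, 3))
A = rng.normal(size=(5, 2, 3))  # batched 2-D inputs
B = rng.normal(size=(5, 3, 4))

# case (1, 1): batched inner product (vecdot)
assert np.allclose(np.einsum("bi,bi->b", x, y), np.stack([a @ b for a, b in zip(x, y)]))
# case (2, 1): batched matrix-vector product (matvec)
assert np.allclose(np.einsum("bij,bj->bi", A, y), np.stack([a @ b for a, b in zip(A, y)]))
# case (1, 2): batched vector-matrix product (vecmat, real-valued)
assert np.allclose(np.einsum("bi,bij->bj", x, B), np.stack([a @ b for a, b in zip(x, B)]))
# case (2, 2): batched matrix-matrix product (matmul)
assert np.allclose(np.matmul(A, B), np.stack([a @ b for a, b in zip(A, B)]))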

pytensor/tensor/rewriting/blas.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import (
     Dot,
-    _matrix_matrix_matmul,
+    _matmul,
     add,
     mul,
     neg,
@@ -908,7 +908,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
 
 
 @register_specialize
-@node_rewriter([_matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def specialize_matmul_to_batched_dot(fgraph, node):
     """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
 
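For context, the case this rewrite targets is a Blockwise matmul whose operands carry the same explicit batch dimension, i.e. a plain stack of independent matrix products; NumPy illustration (not part of the diff):

import numpy as np

rng = np.random.default_rng(7)
x = rng.normal(size=(8, 2, 3))
y = rng.normal(size=(8, 3, 4))

# No implicit broadcasting between the batch dimensions: the batched matmul is
# exactly a per-item ("batched dot") loop over the leading axis.
assert np.allclose(np.matmul(x, y), np.stack([a @ b for a, b in zip(x, y)]))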

pytensor/tensor/rewriting/elemwise.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@
     broadcasted_by,
     register_canonicalize,
     register_specialize,
+    register_stabilize,
 )
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
 
 
 @register_canonicalize
+@register_stabilize
 @register_specialize
 @node_rewriter([DimShuffle])
 def local_dimshuffle_lift(fgraph, node):
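The identity this rewrite exploits is that a dimension shuffle commutes with elementwise operations, so it can be lifted towards the inputs (and often cancelled); a NumPy analogue (illustrative only):

import numpy as np

rng = np.random.default_rng(8)
x = rng.normal(size=(2, 3, 4))

# Transposing after an elementwise op equals applying the op to the transposed input.
assert np.allclose(np.exp(x).transpose(2, 0, 1), np.exp(x.transpose(2, 0, 1)))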

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
+from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
 from pytensor.tensor.nlinalg import (
     SVD,
     KroneckerProduct,
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
                 # This rewrite only applies to matrix Dot
                 and A.owner.inputs[0].type.ndim == 2
             )
-            or (A.owner.op == _matrix_matrix_matmul)
+            or (A.owner.op == _matmul)
         )
     ):
         return
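The change above only widens which dot Ops the rewrite matches; the underlying identity is that the Cholesky factor of L @ L.T is L itself when L is lower triangular with a positive diagonal, checked here in NumPy (illustrative only):

import numpy as np

rng = np.random.default_rng(9)
L = np.tril(rng.normal(size=(4, 4)))
np.fill_diagonal(L, np.abs(np.diag(L)) + 1.0)  # ensure a strictly positive diagonal

assert np.allclose(np.linalg.cholesky(L @ L.T), L)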

pytensor/tensor/rewriting/math.py

Lines changed: 116 additions & 45 deletions
@@ -28,14 +28,14 @@
     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_underlying_scalar_constant_value,
     moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
-from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
@@ -45,10 +45,7 @@
     Sum,
     _conj,
     _dot,
-    _inner_prod,
-    _matrix_matrix_matmul,
-    _matrix_vec_prod,
-    _vec_matrix_prod,
+    _matmul,
     add,
     digamma,
     dot,
@@ -197,60 +194,134 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret
 
 
-@register_stabilize
+@register_canonicalize
 @register_specialize
-@node_rewriter(tracks=[Blockwise])
+@node_rewriter(tracks=[_matmul])
 def local_batched_matmul_to_core_matmul(fgraph, node):
-    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+    """Move batch dimensions of matmul operands to core matmul
 
-    Example, if x has batch dimensions, but y not:
+    Example, if x has batch dimensions that don't overlap with batch dimensions of y
     x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
 
-    It also works when y has batch dimensions, but x not.
+    It also works for batch dimensions of y that don't overlap with batch dimensions of x
     """
 
-    # Check whether we have a matmul operation in this node
-    if not (
-        isinstance(node.op.core_op, Dot)
-        and len(node.op.inputs_sig[0]) == 2
-        and len(node.op.inputs_sig[1]) == 2
-    ):
-        return None
-
     x, y = node.inputs
     batch_ndim = node.op.batch_ndim(node)
 
-    # Check if x has batch dimensions, but y not (or only broadcastable dimensions)
-    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
-        y.type.broadcastable[:-2]
-    ):
-        x_stacked = x.reshape((-1, x.shape[-1]))
-        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
-        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
-        return [out]
-
-    # Otherwise, check if y has batch dimension, but x not
-    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
-        x.type.broadcastable[:-2]
-    ):
-        # For the y batch case we need to first move the batch axes and then reshape
-        # y.shape == (*b, k, n)
-        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
-        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
-        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
-        out_stacked_tr = out_stacked.reshape(
-            (x.shape[-2], *y.shape[:-2], y.shape[-1])
-        )  # (m, *b, n)
-        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
-        return [out]
-
-    # Both x and y have batch dimensions, nothing to do here
-    return None
+    x_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_y and not bcast_x
+    ]
+
+    y_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_x and not bcast_y
+    ]
+
+    if not (x_axis_to_merge or y_axis_to_merge):
+        return None
+
+    x_shape = tuple(x.shape)
+    y_shape = tuple(y.shape)
+    x_is_row = x.type.broadcastable[-2]
+    y_is_col = y.type.broadcastable[-1]
+    n_x_axis_to_merge = len(x_axis_to_merge)
+    n_y_axis_to_merge = len(y_axis_to_merge)
+    n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge
+
+    x_stacked, y_stacked = x, y
+    dims_were_merged = False
+
+    if n_x_axis_to_merge:
+        # ravel batch dimensions of x on the core (m) axis
+        x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
+        x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
+        if x_is_row:
+            # x was a row matrix, squeeze it to clean up the graph
+            x_stacked = x_stacked.squeeze(-2)
+        if n_x_axis_to_merge > 1 or not x_is_row:
+            # Ravel moved batch dims together with (m) if needed
+            x_stacked_shape = tuple(x_stacked.shape)
+            x_stacked = x_stacked.reshape(
+                (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
+            )
+            dims_were_merged = True
+
+    if n_y_axis_to_merge:
+        # ravel batch dimensions of y on the core (n) axis
+        y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
+        y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
+        if y_is_col:
+            # y was a column matrix, squeeze it to clean up the graph
+            y_stacked = y_stacked.squeeze(-1)
+        if n_y_axis_to_merge > 1 or not y_is_col:
+            # Ravel moved batch dims together with (n) if needed
+            y_stacked_shape = tuple(y_stacked.shape)
+            y_stacked = y_stacked.reshape(
+                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
+            )
+            dims_were_merged = True
+
+    # Squeeze x_dims corresponding to merged dimensions of y
+    x_axis_to_squeeze = np.array(y_axis_to_merge)
+    for i in reversed(x_axis_to_merge):
+        # The corresponding dimensions of y may have shifted when we merged dimensions of x
+        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
+    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
+
+    # Same for y
+    y_axis_to_squeeze = np.array(x_axis_to_merge)
+    for i in reversed(y_axis_to_merge):
+        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
+    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
+
+    out_stacked = x_stacked @ y_stacked
+
+    # Split back any merged dimensions
+    if dims_were_merged:
+        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
+        if not x_is_row:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            x_merged_shapes.append(x_shape[-2])
+        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
+        if not y_is_col:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            y_merged_shapes.append(y_shape[-1])
+        out_stacked_shape = tuple(out_stacked.shape)
+        out_unstacked = out_stacked.reshape(
+            (
+                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
+                *x_merged_shapes,
+                *y_merged_shapes,
+            )
+        )
+    else:
+        out_unstacked = out_stacked
+
+    # Add back dummy row, col axis
+    # We do this separately to avoid the reshape as much as we can
+    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -1)
+    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
+
+    # Move batch axis back to their original location
+    source = range(-n_axis_to_merge - 2, 0)
+    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
+    out = moveaxis(out_unstacked, source, destination)
+    return [out]
 
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def local_blockwise_dot_to_mul(fgraph, node):
     """Rewrite blockwise dots that correspond to multiplication without summation.
 
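The reshape trick used by local_batched_matmul_to_core_matmul above can be checked directly in NumPy for the simplest case, a batch dimension on x only (illustrative only, not the rewrite itself):

import numpy as np

rng = np.random.default_rng(10)
x = rng.normal(size=(5, 2, 3))  # batched
y = rng.normal(size=(3, 4))     # unbatched

batched = np.matmul(x, y)                         # broadcasted batched matmul
merged = (x.reshape(-1, 3) @ y).reshape(5, 2, 4)  # batch dim folded into the row (m) axis
assert np.allclose(batched, merged)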
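And a sketch of the kind of graph local_blockwise_dot_to_mul targets: when the contracted dimension has length 1 there is nothing to sum over, so the matmul is just a broadcasted elementwise multiplication (NumPy, illustrative only):

import numpy as np

rng = np.random.default_rng(11)
col = rng.normal(size=(5, 3, 1))  # (..., m, 1)
row = rng.normal(size=(5, 1, 4))  # (..., 1, n)

assert np.allclose(np.matmul(col, row), col * row)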