Skip to content

Commit 5210f2d

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement dot for XTensorVariables (#1475)
1 parent 2de3a27 commit 5210f2d

File tree

5 files changed

+329
-2
lines changed

5 files changed

+329
-2
lines changed

pytensor/xtensor/math.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import sys
2+
from collections.abc import Iterable
3+
from types import EllipsisType
24

35
import numpy as np
46

57
import pytensor.scalar as ps
68
from pytensor import config
9+
from pytensor.graph.basic import Apply
710
from pytensor.scalar import ScalarOp
8-
from pytensor.scalar.basic import _cast_mapping
9-
from pytensor.xtensor.basic import as_xtensor
11+
from pytensor.scalar.basic import _cast_mapping, upcast
12+
from pytensor.xtensor.basic import XOp, as_xtensor
13+
from pytensor.xtensor.type import xtensor
1014
from pytensor.xtensor.vectorization import XElemwise
1115

1216

@@ -134,3 +138,110 @@ def cast(x, dtype):
134138
if dtype not in _xelemwise_cast_op:
135139
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
136140
return _xelemwise_cast_op[dtype](x)
141+
142+
143+
class XDot(XOp):
144+
"""Matrix multiplication between two XTensorVariables.
145+
146+
This operation performs matrix multiplication between two tensors, automatically
147+
aligning and contracting dimensions. The behavior matches xarray's dot operation.
148+
149+
Parameters
150+
----------
151+
dims : tuple of str
152+
The dimensions to contract over. If None, will contract over all matching dimensions.
153+
"""
154+
155+
__props__ = ("dims",)
156+
157+
def __init__(self, dims: Iterable[str]):
158+
self.dims = dims
159+
super().__init__()
160+
161+
def make_node(self, x, y):
162+
x = as_xtensor(x)
163+
y = as_xtensor(y)
164+
165+
x_shape_dict = dict(zip(x.type.dims, x.type.shape))
166+
y_shape_dict = dict(zip(y.type.dims, y.type.shape))
167+
168+
# Check for dimension size mismatches (concrete only)
169+
for dim in self.dims:
170+
x_shape = x_shape_dict.get(dim, None)
171+
y_shape = y_shape_dict.get(dim, None)
172+
if (
173+
isinstance(x_shape, int)
174+
and isinstance(y_shape, int)
175+
and x_shape != y_shape
176+
):
177+
raise ValueError(f"Size of dim '{dim}' does not match")
178+
179+
# Determine output dimensions
180+
shape_dict = {**x_shape_dict, **y_shape_dict}
181+
out_dims = tuple(d for d in shape_dict if d not in self.dims)
182+
183+
# Determine output shape
184+
out_shape = tuple(shape_dict[d] for d in out_dims)
185+
186+
# Determine output dtype
187+
out_dtype = upcast(x.type.dtype, y.type.dtype)
188+
189+
out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
190+
return Apply(self, [x, y], [out])
191+
192+
193+
def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
194+
"""Matrix multiplication between two XTensorVariables.
195+
196+
This operation performs matrix multiplication between two tensors, automatically
197+
aligning and contracting dimensions. The behavior matches xarray's dot operation.
198+
199+
Parameters
200+
----------
201+
x : XTensorVariable
202+
First input tensor
203+
y : XTensorVariable
204+
Second input tensor
205+
dim : str, Iterable[Hashable], EllipsisType, or None, optional
206+
The dimensions to contract over. If None, will contract over all matching dimensions.
207+
If Ellipsis (...), will contract over all dimensions.
208+
209+
Returns
210+
-------
211+
XTensorVariable
212+
The result of the matrix multiplication.
213+
214+
Examples
215+
--------
216+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
217+
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
218+
>>> z = dot(x, y) # Result has dimensions ("a", "c")
219+
>>> z = dot(x, y, dim=...) # Contract over all dimensions
220+
"""
221+
x = as_xtensor(x)
222+
y = as_xtensor(y)
223+
224+
x_dims = set(x.type.dims)
225+
y_dims = set(y.type.dims)
226+
intersection = x_dims & y_dims
227+
union = x_dims | y_dims
228+
229+
# Canonicalize dims
230+
if dim is None:
231+
dim_set = intersection
232+
elif dim is ...:
233+
dim_set = union
234+
elif isinstance(dim, str):
235+
dim_set = {dim}
236+
elif isinstance(dim, Iterable):
237+
dim_set = set(dim)
238+
239+
# Validate provided dims
240+
# Check if any dimension is not found in either input
241+
for d in dim_set:
242+
if d not in union:
243+
raise ValueError(f"Dimension {d} not found in either input")
244+
245+
result = XDot(dims=tuple(dim_set))(x, y)
246+
247+
return result
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytensor.xtensor.rewriting.basic
22
import pytensor.xtensor.rewriting.indexing
3+
import pytensor.xtensor.rewriting.math
34
import pytensor.xtensor.rewriting.reduction
45
import pytensor.xtensor.rewriting.shape
56
import pytensor.xtensor.rewriting.vectorization

pytensor/xtensor/rewriting/math.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from string import ascii_lowercase
2+
3+
from pytensor.graph import node_rewriter
4+
from pytensor.tensor import einsum
5+
from pytensor.tensor.shape import specify_shape
6+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
7+
from pytensor.xtensor.math import XDot
8+
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
9+
10+
11+
@register_lower_xtensor
12+
@node_rewriter(tracks=[XDot])
13+
def lower_dot(fgraph, node):
14+
"""Rewrite XDot to tensor.dot.
15+
16+
This rewrite converts an XDot operation to a tensor-based dot operation,
17+
handling dimension alignment and contraction.
18+
"""
19+
[x, y] = node.inputs
20+
[out] = node.outputs
21+
22+
# Convert inputs to tensors
23+
x_tensor = tensor_from_xtensor(x)
24+
y_tensor = tensor_from_xtensor(y)
25+
26+
# Collect all dimension names across inputs and output
27+
all_dims = list(
28+
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims)
29+
) # preserve order
30+
if len(all_dims) > len(ascii_lowercase):
31+
raise ValueError("Too many dimensions to map to einsum subscripts")
32+
33+
dim_to_char = dict(zip(all_dims, ascii_lowercase))
34+
35+
# Build einsum string
36+
x_subs = "".join(dim_to_char[d] for d in x.type.dims)
37+
y_subs = "".join(dim_to_char[d] for d in y.type.dims)
38+
out_subs = "".join(dim_to_char[d] for d in out.type.dims)
39+
einsum_str = f"{x_subs},{y_subs}->{out_subs}"
40+
41+
# Perform the einsum operation
42+
out_tensor = einsum(einsum_str, x_tensor, y_tensor)
43+
44+
# Reshape to match the output shape
45+
out_tensor = specify_shape(out_tensor, out.type.shape)
46+
47+
return [xtensor_from_tensor(out_tensor, out.type.dims)]

pytensor/xtensor/type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,10 @@ def stack(self, dim, **dims):
649649
def unstack(self, dim, **dims):
650650
return px.shape.unstack(self, dim, **dims)
651651

652+
def dot(self, other, dim=None):
653+
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
654+
return px.math.dot(self, other, dim=dim)
655+
652656

653657
class XTensorConstantSignature(TensorConstantSignature):
654658
pass

tests/xtensor/test_math.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,167 @@ def test_cast():
151151
yc64 = x.astype("complex64")
152152
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
153153
yc64.astype("float64")
154+
155+
156+
def test_dot():
157+
"""Test basic dot product operations."""
158+
# Test matrix-vector dot product (with multiple-letter dim names)
159+
x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
160+
y = xtensor("y", dims=("bb",), shape=(3,))
161+
z = x.dot(y)
162+
fn = xr_function([x, y], z)
163+
164+
x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
165+
y_test = DataArray(np.ones(3), dims=("bb",))
166+
z_test = fn(x_test, y_test)
167+
expected = x_test.dot(y_test)
168+
xr_assert_allclose(z_test, expected)
169+
170+
# Test matrix-vector dot product with ellipsis
171+
z = x.dot(y, dim=...)
172+
fn = xr_function([x, y], z)
173+
z_test = fn(x_test, y_test)
174+
expected = x_test.dot(y_test, dim=...)
175+
xr_assert_allclose(z_test, expected)
176+
177+
# Test matrix-matrix dot product
178+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
179+
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
180+
z = x.dot(y)
181+
fn = xr_function([x, y], z)
182+
183+
x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
184+
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
185+
z_test = fn(x_test, y_test)
186+
expected = x_test.dot(y_test)
187+
xr_assert_allclose(z_test, expected)
188+
189+
# Test matrix-matrix dot product with string dim
190+
z = x.dot(y, dim="b")
191+
fn = xr_function([x, y], z)
192+
z_test = fn(x_test, y_test)
193+
expected = x_test.dot(y_test, dim="b")
194+
xr_assert_allclose(z_test, expected)
195+
196+
# Test matrix-matrix dot product with list of dims
197+
z = x.dot(y, dim=["b"])
198+
fn = xr_function([x, y], z)
199+
z_test = fn(x_test, y_test)
200+
expected = x_test.dot(y_test, dim=["b"])
201+
xr_assert_allclose(z_test, expected)
202+
203+
# Test matrix-matrix dot product with ellipsis
204+
z = x.dot(y, dim=...)
205+
fn = xr_function([x, y], z)
206+
z_test = fn(x_test, y_test)
207+
expected = x_test.dot(y_test, dim=...)
208+
xr_assert_allclose(z_test, expected)
209+
210+
# Test a case where there are two dimensions to sum over
211+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
212+
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
213+
z = x.dot(y)
214+
fn = xr_function([x, y], z)
215+
216+
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
217+
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
218+
z_test = fn(x_test, y_test)
219+
expected = x_test.dot(y_test)
220+
xr_assert_allclose(z_test, expected)
221+
222+
# Same but with explicit dimensions
223+
z = x.dot(y, dim=["b", "c"])
224+
fn = xr_function([x, y], z)
225+
z_test = fn(x_test, y_test)
226+
expected = x_test.dot(y_test, dim=["b", "c"])
227+
xr_assert_allclose(z_test, expected)
228+
229+
# Same but with ellipses
230+
z = x.dot(y, dim=...)
231+
fn = xr_function([x, y], z)
232+
z_test = fn(x_test, y_test)
233+
expected = x_test.dot(y_test, dim=...)
234+
xr_assert_allclose(z_test, expected)
235+
236+
# Dot product with sum
237+
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
238+
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
239+
expected = x_test.dot(y_test, dim=("a", "b", "c"))
240+
241+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
242+
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
243+
z = x.dot(y, dim=("a", "b", "c"))
244+
fn = xr_function([x, y], z)
245+
z_test = fn(x_test, y_test)
246+
xr_assert_allclose(z_test, expected)
247+
248+
# Dot product with sum in the middle
249+
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
250+
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
251+
expected = x_test.dot(y_test, dim=("b", "d"))
252+
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
253+
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
254+
z = x.dot(y, dim=("b", "d"))
255+
fn = xr_function([x, y], z)
256+
z_test = fn(x_test, y_test)
257+
xr_assert_allclose(z_test, expected)
258+
259+
# Same but with first two dims
260+
expected = x_test.dot(y_test, dim=["a", "b"])
261+
z = x.dot(y, dim=["a", "b"])
262+
fn = xr_function([x, y], z)
263+
z_test = fn(x_test, y_test)
264+
xr_assert_allclose(z_test, expected)
265+
266+
# Same but with last two
267+
expected = x_test.dot(y_test, dim=["d", "e"])
268+
z = x.dot(y, dim=["d", "e"])
269+
fn = xr_function([x, y], z)
270+
z_test = fn(x_test, y_test)
271+
xr_assert_allclose(z_test, expected)
272+
273+
# Same but with every other dim
274+
expected = x_test.dot(y_test, dim=["a", "c", "e"])
275+
z = x.dot(y, dim=["a", "c", "e"])
276+
fn = xr_function([x, y], z)
277+
z_test = fn(x_test, y_test)
278+
xr_assert_allclose(z_test, expected)
279+
280+
# Test symbolic shapes
281+
x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic
282+
y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic
283+
z = x.dot(y)
284+
fn = xr_function([x, y], z)
285+
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
286+
y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
287+
z_test = fn(x_test, y_test)
288+
expected = x_test.dot(y_test)
289+
xr_assert_allclose(z_test, expected)
290+
291+
292+
def test_dot_errors():
293+
# No matching dimensions
294+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
295+
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
296+
with pytest.raises(ValueError, match="Dimension e not found in either input"):
297+
x.dot(y, dim="e")
298+
299+
# Concrete dimension size mismatches
300+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
301+
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
302+
with pytest.raises(
303+
ValueError,
304+
match="Size of dim 'b' does not match",
305+
):
306+
x.dot(y)
307+
308+
# Symbolic dimension size mismatches
309+
x = xtensor("x", dims=("a", "b"), shape=(2, None))
310+
y = xtensor("y", dims=("b", "c"), shape=(None, 5))
311+
z = x.dot(y)
312+
fn = xr_function([x, y], z)
313+
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
314+
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
315+
# Doesn't fail until the rewrite
316+
with pytest.raises(ValueError, match="not aligned"):
317+
fn(x_test, y_test)

0 commit comments

Comments
 (0)