Skip to content

Commit f0e8e29

Browse files
AllenDowney authored and ricardoV94 committed
Implement dot for XTensorVariables (#1475)
1 parent 7a9db22 commit f0e8e29

File tree

6 files changed

+328
-2
lines changed

6 files changed

+328
-2
lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import linalg
5+
from pytensor.xtensor.math import dot
56
from pytensor.xtensor.shape import concat
67
from pytensor.xtensor.type import (
78
as_xtensor,

pytensor/xtensor/math.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import pytensor.scalar as ps
66
from pytensor import config
7+
from pytensor.graph.basic import Apply
78
from pytensor.scalar import ScalarOp
8-
from pytensor.scalar.basic import _cast_mapping
9-
from pytensor.xtensor.basic import as_xtensor
9+
from pytensor.scalar.basic import _cast_mapping, upcast
10+
from pytensor.xtensor.basic import XOp, as_xtensor
11+
from pytensor.xtensor.type import xtensor
1012
from pytensor.xtensor.vectorization import XElemwise
1113

1214

@@ -139,3 +141,110 @@ def cast(x, dtype):
139141
def softmax(x, dim=None):
    """Return ``exp(x)`` divided by its sum over ``dim``.

    ``dim`` is forwarded unchanged to the ``sum`` method of the exponentiated
    input.
    """
    numerator = exp(x)
    denominator = numerator.sum(dim=dim)
    return numerator / denominator  # type: ignore
144+
145+
146+
class XDot(XOp):
    """Contract named dimensions between two XTensorVariables.

    Every dimension listed in ``dims`` is dropped from the output; all other
    dimensions of the two inputs are kept, ordered as they first appear in
    ``x`` and then in ``y``.

    Parameters
    ----------
    dims : iterable of str
        The dimensions to contract over. This Op expects an explicit
        collection; the ``dot`` helper function resolves ``None`` and ``...``
        into concrete dimension names before building it.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Iterable[str]):
        # Store an immutable, re-iterable copy: ``make_node`` walks ``dims``
        # more than once (a generator would be exhausted after the first
        # pass), and ``dims`` is listed in ``__props__``, so a list would
        # leave the Op with an unhashable prop.
        self.dims = tuple(dims)
        super().__init__()

    def make_node(self, x, y):
        """Build the Apply node and infer the output dims, shape and dtype.

        Raises
        ------
        ValueError
            If a contracted dimension has different statically known sizes
            in ``x`` and ``y``.
        """
        x = as_xtensor(x)
        y = as_xtensor(y)

        x_shape_dict = dict(zip(x.type.dims, x.type.shape))
        y_shape_dict = dict(zip(y.type.dims, y.type.shape))

        # Only statically known (int) sizes can be compared here; sizes that
        # are unknown at this point are left unchecked.
        for dim in self.dims:
            x_size = x_shape_dict.get(dim)
            y_size = y_shape_dict.get(dim)
            if (
                isinstance(x_size, int)
                and isinstance(y_size, int)
                and x_size != y_size
            ):
                raise ValueError(f"Size of dim '{dim}' does not match")

        # Merge the per-dim sizes, keeping x's dims first and appending the
        # dims that only y has. For a shared dim, y's size wins unless it is
        # unknown (None), in which case the size recorded for x is kept so
        # static shape information is not thrown away.
        shape_dict = dict(x_shape_dict)
        for dim, size in y_shape_dict.items():
            if size is not None or dim not in shape_dict:
                shape_dict[dim] = size

        # Output keeps every dim that is not contracted
        out_dims = tuple(d for d in shape_dict if d not in self.dims)
        out_shape = tuple(shape_dict[d] for d in out_dims)
        out_dtype = upcast(x.type.dtype, y.type.dtype)

        out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, y], [out])
194+
195+
196+
def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
    """Matrix multiplication between two XTensorVariables.

    This operation performs matrix multiplication between two tensors, automatically
    aligning and contracting dimensions. The behavior matches xarray's dot operation.

    Parameters
    ----------
    x : XTensorVariable
        First input tensor
    y : XTensorVariable
        Second input tensor
    dim : str, Iterable[Hashable], EllipsisType, or None, optional
        The dimensions to contract over. If None, will contract over the dimensions
        present in both inputs. If Ellipsis (...), will contract over all dimensions
        of both inputs.

    Returns
    -------
    XTensorVariable
        The result of the matrix multiplication.

    Raises
    ------
    TypeError
        If ``dim`` is not None, Ellipsis, a string or an iterable.
    ValueError
        If a requested dimension is not present in either input.

    Examples
    --------
    >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
    >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
    >>> z = dot(x, y)  # Result has dimensions ("a", "c")
    >>> z = dot(x, y, dim=...)  # Contract over all dimensions
    """
    x = as_xtensor(x)
    y = as_xtensor(y)

    x_dims = set(x.type.dims)
    y_dims = set(y.type.dims)
    union = x_dims | y_dims

    # Canonicalize dims into a set of dimension names
    if dim is None:
        dim_set = x_dims & y_dims
    elif dim is ...:
        dim_set = union
    elif isinstance(dim, str):
        # Checked before Iterable: a string is one name, not a sequence of chars
        dim_set = {dim}
    elif isinstance(dim, Iterable):
        dim_set = set(dim)
    else:
        # Without this branch an unsupported argument would leave dim_set
        # unbound and fail below with an unrelated UnboundLocalError.
        raise TypeError(
            "dim must be None, Ellipsis, a string or an iterable of dimension "
            f"names, got {type(dim).__name__}"
        )

    # Validate provided dims: every requested dimension must exist in an input.
    # Report the first offender in a stable order.
    missing = dim_set - union
    if missing:
        d = min(missing, key=str)
        raise ValueError(f"Dimension {d} not found in either input")

    # Sets have no stable iteration order; sort so that the same contraction
    # always yields the same ``dims`` tuple on the Op.
    return XDot(dims=tuple(sorted(dim_set, key=str)))(x, y)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytensor.xtensor.rewriting.basic
22
import pytensor.xtensor.rewriting.indexing
3+
import pytensor.xtensor.rewriting.math
34
import pytensor.xtensor.rewriting.reduction
45
import pytensor.xtensor.rewriting.shape
56
import pytensor.xtensor.rewriting.vectorization

pytensor/xtensor/rewriting/math.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from string import ascii_lowercase
2+
3+
from pytensor.graph import node_rewriter
4+
from pytensor.tensor import einsum
5+
from pytensor.tensor.shape import specify_shape
6+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
7+
from pytensor.xtensor.math import XDot
8+
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
9+
10+
11+
@register_lower_xtensor
@node_rewriter(tracks=[XDot])
def lower_dot(fgraph, node):
    """Lower an XDot node to an einsum over plain tensors.

    Each distinct dimension name found in the inputs and the output is given
    one lowercase letter, and the contraction is expressed as a two-operand
    einsum whose result is wrapped back into an xtensor with the output dims.
    """
    x, y = node.inputs
    (out,) = node.outputs

    # Strip the dimension names from the inputs
    x_tensor = tensor_from_xtensor(x)
    y_tensor = tensor_from_xtensor(y)

    # Distinct dimension names, in order of first appearance
    unique_dims = []
    for name in (*x.type.dims, *y.type.dims, *out.type.dims):
        if name not in unique_dims:
            unique_dims.append(name)
    if len(unique_dims) > len(ascii_lowercase):
        raise ValueError("Too many dimensions to map to einsum subscripts")

    letter_of = {name: ascii_lowercase[i] for i, name in enumerate(unique_dims)}

    def subscripts(dims):
        # Translate a sequence of dimension names into einsum letters
        return "".join(letter_of[d] for d in dims)

    einsum_str = (
        f"{subscripts(x.type.dims)},{subscripts(y.type.dims)}"
        f"->{subscripts(out.type.dims)}"
    )

    # Contract, then attach the shape recorded on the XDot output type
    contracted = einsum(einsum_str, x_tensor, y_tensor)
    contracted = specify_shape(contracted, out.type.shape)

    return [xtensor_from_tensor(contracted, out.type.dims)]

pytensor/xtensor/type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,10 @@ def stack(self, dim, **dims):
649649
def unstack(self, dim, **dims):
    """Forward to ``px.shape.unstack`` with this variable as first argument."""
    unstacked = px.shape.unstack(self, dim, **dims)
    return unstacked
651651

652+
def dot(self, other, dim=None):
    """Contract this variable with ``other`` by forwarding to ``px.math.dot``.

    ``dim`` is passed through unchanged and defaults to ``None``.
    """
    product = px.math.dot(self, other, dim=dim)
    return product
655+
652656

653657
class XTensorConstantSignature(TensorConstantSignature):
654658
pass

tests/xtensor/test_math.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,167 @@ def test_cast():
150150
yc64 = x.astype("complex64")
151151
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
152152
yc64.astype("float64")
153+
154+
155+
def test_dot():
    """Check ``XTensorVariable.dot`` against ``DataArray.dot`` on several layouts."""

    def assert_matches_xarray(x, y, x_val, y_val, **dot_kwargs):
        # Build and compile the symbolic product, evaluate it on the test
        # data and compare with what xarray computes for the same arguments.
        fn = xr_function([x, y], x.dot(y, **dot_kwargs))
        xr_assert_allclose(fn(x_val, y_val), x_val.dot(y_val, **dot_kwargs))

    # Matrix-vector product with multi-letter dim names:
    # default dims, then ellipsis
    x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
    y = xtensor("y", dims=("bb",), shape=(3,))
    x_val = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
    y_val = DataArray(np.ones(3), dims=("bb",))
    for dot_kwargs in ({}, {"dim": ...}):
        assert_matches_xarray(x, y, x_val, y_val, **dot_kwargs)

    # Matrix-matrix product: default dims, a string dim, a list of dims, ellipsis
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(3, 4))
    x_val = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
    y_val = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
    for dot_kwargs in ({}, {"dim": "b"}, {"dim": ["b"]}, {"dim": ...}):
        assert_matches_xarray(x, y, x_val, y_val, **dot_kwargs)

    # Two shared dimensions: default dims, the same dims spelled out, ellipsis,
    # and finally a contraction that also sums over a dim only x has
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
    y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
    x_val = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
    y_val = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
    for dot_kwargs in (
        {},
        {"dim": ["b", "c"]},
        {"dim": ...},
        {"dim": ("a", "b", "c")},
    ):
        assert_matches_xarray(x, y, x_val, y_val, **dot_kwargs)

    # Four-dimensional inputs: non-adjacent shared dims, the first two dims,
    # the last two dims, and every other dim
    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
    y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
    x_val = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
    y_val = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
    for dims in (("b", "d"), ["a", "b"], ["d", "e"], ["a", "c", "e"]):
        assert_matches_xarray(x, y, x_val, y_val, dim=dims)

    # Shapes that are only partly known when the graph is built
    x = xtensor("x", dims=("a", "b"), shape=(None, 3))  # size of "a" unknown
    y = xtensor("y", dims=("b", "c"), shape=(3, None))  # size of "c" unknown
    x_val = DataArray(np.ones((2, 3)), dims=("a", "b"))
    y_val = DataArray(np.ones((3, 4)), dims=("b", "c"))
    assert_matches_xarray(x, y, x_val, y_val)
289+
290+
291+
def test_dot_errors():
    """Check the errors raised for invalid ``dot`` requests."""
    lhs = xtensor("x", dims=("a", "b"), shape=(2, 3))

    # The requested dim is absent from both operands
    rhs = xtensor("y", dims=("b", "c"), shape=(3, 4))
    with pytest.raises(ValueError, match="Dimension e not found in either input"):
        lhs.dot(rhs, dim="e")

    # Statically known sizes of the shared dim disagree
    rhs = xtensor("y", dims=("b", "c"), shape=(4, 5))
    with pytest.raises(ValueError, match="Size of dim 'b' does not match"):
        lhs.dot(rhs)

    # Sizes of the shared dim are unknown when the graph is built, so building
    # and compiling succeed (both happen outside the raises block); the error
    # is expected only once the compiled function is called with the data.
    lhs = xtensor("x", dims=("a", "b"), shape=(2, None))
    rhs = xtensor("y", dims=("b", "c"), shape=(None, 5))
    fn = xr_function([lhs, rhs], lhs.dot(rhs))
    lhs_val = DataArray(np.ones((2, 3)), dims=("a", "b"))
    rhs_val = DataArray(np.ones((4, 5)), dims=("b", "c"))
    with pytest.raises(ValueError, match="not aligned"):
        fn(lhs_val, rhs_val)

0 commit comments

Comments
 (0)