diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py
index 4fe0ca8106..55953ee452 100644
--- a/pytensor/xtensor/math.py
+++ b/pytensor/xtensor/math.py
@@ -1,12 +1,16 @@
 import sys
+from collections.abc import Iterable
+from types import EllipsisType
 
 import numpy as np
 
 import pytensor.scalar as ps
 from pytensor import config
+from pytensor.graph.basic import Apply
 from pytensor.scalar import ScalarOp
-from pytensor.scalar.basic import _cast_mapping
-from pytensor.xtensor.basic import as_xtensor
+from pytensor.scalar.basic import _cast_mapping, upcast
+from pytensor.xtensor.basic import XOp, as_xtensor
+from pytensor.xtensor.type import xtensor
 from pytensor.xtensor.vectorization import XElemwise
@@ -134,3 +138,110 @@ def cast(x, dtype):
     if dtype not in _xelemwise_cast_op:
         _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
     return _xelemwise_cast_op[dtype](x)
+
+
+class XDot(XOp):
+    """Generalized dot product between two XTensorVariables.
+
+    This operation computes the tensor contraction of its two inputs,
+    automatically aligning and contracting dimensions. The behavior matches
+    xarray's dot operation.
+
+    Parameters
+    ----------
+    dims : tuple of str
+        The dimensions to contract over.
+    """
+
+    __props__ = ("dims",)
+
+    def __init__(self, dims: Iterable[str]):
+        # Canonicalize to a tuple so __props__ stays hashable
+        self.dims = tuple(dims)
+        super().__init__()
+
+    def make_node(self, x, y):
+        x = as_xtensor(x)
+        y = as_xtensor(y)
+
+        x_shape_dict = dict(zip(x.type.dims, x.type.shape))
+        y_shape_dict = dict(zip(y.type.dims, y.type.shape))
+
+        # Check for dimension size mismatches (concrete sizes only;
+        # symbolic sizes can only be checked at runtime)
+        for dim in self.dims:
+            x_shape = x_shape_dict.get(dim, None)
+            y_shape = y_shape_dict.get(dim, None)
+            if (
+                isinstance(x_shape, int)
+                and isinstance(y_shape, int)
+                and x_shape != y_shape
+            ):
+                raise ValueError(f"Size of dim '{dim}' does not match")
+
+        # Determine output dimensions: everything not contracted over
+        shape_dict = {**x_shape_dict, **y_shape_dict}
+        out_dims = tuple(d for d in shape_dict if d not in self.dims)
+
+        # Determine output shape
+        out_shape = tuple(shape_dict[d] for d in out_dims)
+
+        # Determine output dtype
+        out_dtype = upcast(x.type.dtype, y.type.dtype)
+
+        out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
+        return Apply(self, [x, y], [out])
+
+
+def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
+    """Generalized dot product between two XTensorVariables.
+
+    This operation computes the tensor contraction of its two inputs,
+    automatically aligning and contracting dimensions. The behavior matches
+    xarray's dot operation.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        First input tensor
+    y : XTensorVariable
+        Second input tensor
+    dim : str, Iterable[Hashable], EllipsisType, or None, optional
+        The dimensions to contract over. If None, will contract over all
+        matching dimensions. If Ellipsis (...), will contract over all
+        dimensions.
+
+    Returns
+    -------
+    XTensorVariable
+        The result of the contraction.
+
+    Examples
+    --------
+    >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
+    >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
+    >>> z = dot(x, y)  # Result has dimensions ("a", "c")
+    >>> z = dot(x, y, dim=...)  # Contract over all dimensions
+    """
+    x = as_xtensor(x)
+    y = as_xtensor(y)
+
+    x_dims = set(x.type.dims)
+    y_dims = set(y.type.dims)
+    intersection = x_dims & y_dims
+    union = x_dims | y_dims
+
+    # Canonicalize dims
+    if dim is None:
+        dim_set = intersection
+    elif dim is ...:
+        dim_set = union
+    elif isinstance(dim, str):
+        dim_set = {dim}
+    elif isinstance(dim, Iterable):
+        dim_set = set(dim)
+    else:
+        raise TypeError(f"Invalid type for dim: {type(dim)}")
+
+    # Validate provided dims: every requested dim must exist in at least one input
+    for d in dim_set:
+        if d not in union:
+            raise ValueError(f"Dimension {d} not found in either input")
+
+    return XDot(dims=tuple(dim_set))(x, y)
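A minimal usage sketch of the new `dot` entry point, assuming only the API added above (`xtensor` from `pytensor.xtensor.type`, `dot` from `pytensor.xtensor.math`):

    from pytensor.xtensor.math import dot
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(3, 4))

    z = dot(x, y)  # dim=None -> contract the shared dim "b"
    print(z.type.dims, z.type.shape)  # ("a", "c") (2, 4)

    z_all = dot(x, y, dim=...)  # Ellipsis -> contract every dim
    print(z_all.type.dims, z_all.type.shape)  # () ()
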
diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py
index a65ad0db85..bdbb30f147 100644
--- a/pytensor/xtensor/rewriting/__init__.py
+++ b/pytensor/xtensor/rewriting/__init__.py
@@ -1,5 +1,6 @@
 import pytensor.xtensor.rewriting.basic
 import pytensor.xtensor.rewriting.indexing
+import pytensor.xtensor.rewriting.math
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py
new file mode 100644
index 0000000000..850d91fad3
--- /dev/null
+++ b/pytensor/xtensor/rewriting/math.py
@@ -0,0 +1,47 @@
+from string import ascii_lowercase
+
+from pytensor.graph import node_rewriter
+from pytensor.tensor import einsum
+from pytensor.tensor.shape import specify_shape
+from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
+from pytensor.xtensor.math import XDot
+from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XDot])
+def lower_dot(fgraph, node):
+    """Lower XDot to an einsum on the underlying tensors.
+
+    This rewrite converts an XDot operation to a tensor-based einsum,
+    handling dimension alignment and contraction.
+    """
+    [x, y] = node.inputs
+    [out] = node.outputs
+
+    # Convert inputs to tensors
+    x_tensor = tensor_from_xtensor(x)
+    y_tensor = tensor_from_xtensor(y)
+
+    # Collect all dimension names across inputs and output, preserving order
+    all_dims = list(dict.fromkeys(x.type.dims + y.type.dims + out.type.dims))
+    if len(all_dims) > len(ascii_lowercase):
+        raise ValueError("Too many dimensions to map to einsum subscripts")
+
+    dim_to_char = dict(zip(all_dims, ascii_lowercase))
+
+    # Build einsum string, e.g. "ab,bc->ac"
+    x_subs = "".join(dim_to_char[d] for d in x.type.dims)
+    y_subs = "".join(dim_to_char[d] for d in y.type.dims)
+    out_subs = "".join(dim_to_char[d] for d in out.type.dims)
+    einsum_str = f"{x_subs},{y_subs}->{out_subs}"
+
+    # Perform the einsum operation
+    out_tensor = einsum(einsum_str, x_tensor, y_tensor)
+
+    # Attach the known static output shape (specify_shape asserts, not reshapes)
+    out_tensor = specify_shape(out_tensor, out.type.shape)
+
+    return [xtensor_from_tensor(out_tensor, out.type.dims)]
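For intuition, here is the subscript construction in isolation: a standalone sketch with hypothetical dims (contracting "b" and "c"), not pytensor code:

    from string import ascii_lowercase

    # Hypothetical dims for x, y, and the output
    x_dims, y_dims, out_dims = ("a", "b", "c"), ("b", "c", "d"), ("a", "d")

    all_dims = list(dict.fromkeys(x_dims + y_dims + out_dims))  # order-preserving dedup
    dim_to_char = dict(zip(all_dims, ascii_lowercase))

    x_subs = "".join(dim_to_char[d] for d in x_dims)
    y_subs = "".join(dim_to_char[d] for d in y_dims)
    out_subs = "".join(dim_to_char[d] for d in out_dims)
    print(f"{x_subs},{y_subs}->{out_subs}")  # abc,bcd->ad
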
+ """ + [x, y] = node.inputs + [out] = node.outputs + + # Convert inputs to tensors + x_tensor = tensor_from_xtensor(x) + y_tensor = tensor_from_xtensor(y) + + # Collect all dimension names across inputs and output + all_dims = list( + dict.fromkeys(x.type.dims + y.type.dims + out.type.dims) + ) # preserve order + if len(all_dims) > len(ascii_lowercase): + raise ValueError("Too many dimensions to map to einsum subscripts") + + dim_to_char = dict(zip(all_dims, ascii_lowercase)) + + # Build einsum string + x_subs = "".join(dim_to_char[d] for d in x.type.dims) + y_subs = "".join(dim_to_char[d] for d in y.type.dims) + out_subs = "".join(dim_to_char[d] for d in out.type.dims) + einsum_str = f"{x_subs},{y_subs}->{out_subs}" + + # Perform the einsum operation + out_tensor = einsum(einsum_str, x_tensor, y_tensor) + + # Reshape to match the output shape + out_tensor = specify_shape(out_tensor, out.type.shape) + + return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 96b0a1fd7c..f16a7432bf 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -650,6 +650,10 @@ def stack(self, dim, **dims): def unstack(self, dim, **dims): return px.shape.unstack(self, dim, **dims) + def dot(self, other, dim=None): + """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" + return px.math.dot(self, other, dim=dim) + class XTensorConstantSignature(tuple): def __eq__(self, other): diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 210bfe9a80..088948de85 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -151,3 +151,167 @@ def test_cast(): yc64 = x.astype("complex64") with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): yc64.astype("float64") + + +def test_dot(): + """Test basic dot product operations.""" + # Test matrix-vector dot product (with multiple-letter dim names) + x = xtensor("x", dims=("aa", "bb"), shape=(2, 3)) + y = xtensor("y", dims=("bb",), shape=(3,)) + z = x.dot(y) + fn = xr_function([x, y], z) + + x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb")) + y_test = DataArray(np.ones(3), dims=("bb",)) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + + # Test matrix-vector dot product with ellipsis + z = x.dot(y, dim=...) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + z = x.dot(y) + fn = xr_function([x, y], z) + + x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b")) + y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c")) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with string dim + z = x.dot(y, dim="b") + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim="b") + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with list of dims + z = x.dot(y, dim=["b"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=["b"]) + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with ellipsis + z = x.dot(y, dim=...) 
diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py
index 210bfe9a80..088948de85 100644
--- a/tests/xtensor/test_math.py
+++ b/tests/xtensor/test_math.py
@@ -151,3 +151,167 @@ def test_cast():
     yc64 = x.astype("complex64")
     with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
         yc64.astype("float64")
+
+
+def test_dot():
+    """Test basic dot product operations."""
+    # Test matrix-vector dot product (with multiple-letter dim names)
+    x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
+    y = xtensor("y", dims=("bb",), shape=(3,))
+    z = x.dot(y)
+    fn = xr_function([x, y], z)
+
+    x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
+    y_test = DataArray(np.ones(3), dims=("bb",))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-vector dot product with ellipsis
+    z = x.dot(y, dim=...)
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim=...)
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-matrix dot product
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(3, 4))
+    z = x.dot(y)
+    fn = xr_function([x, y], z)
+
+    x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
+    y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-matrix dot product with string dim
+    z = x.dot(y, dim="b")
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim="b")
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-matrix dot product with list of dims
+    z = x.dot(y, dim=["b"])
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim=["b"])
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-matrix dot product with ellipsis
+    z = x.dot(y, dim=...)
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim=...)
+    xr_assert_allclose(z_test, expected)
+
+    # Test a case where there are two dimensions to sum over
+    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
+    y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
+    z = x.dot(y)
+    fn = xr_function([x, y], z)
+
+    x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
+    y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Same but with explicit dimensions
+    z = x.dot(y, dim=["b", "c"])
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim=["b", "c"])
+    xr_assert_allclose(z_test, expected)
+
+    # Same but with ellipsis
+    z = x.dot(y, dim=...)
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test, dim=...)
+    xr_assert_allclose(z_test, expected)
+
+    # Dot product that also sums over a non-shared dim
+    x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
+    y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
+    expected = x_test.dot(y_test, dim=("a", "b", "c"))
+
+    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
+    y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
+    z = x.dot(y, dim=("a", "b", "c"))
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Dot product with sum in the middle
+    x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
+    y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
+    expected = x_test.dot(y_test, dim=("b", "d"))
+    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
+    y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
+    z = x.dot(y, dim=("b", "d"))
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Same but with first two dims
+    expected = x_test.dot(y_test, dim=["a", "b"])
+    z = x.dot(y, dim=["a", "b"])
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Same but with last two
+    expected = x_test.dot(y_test, dim=["d", "e"])
+    z = x.dot(y, dim=["d", "e"])
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Same but with every other dim
+    expected = x_test.dot(y_test, dim=["a", "c", "e"])
+    z = x.dot(y, dim=["a", "c", "e"])
+    fn = xr_function([x, y], z)
+    z_test = fn(x_test, y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Test symbolic shapes
+    x = xtensor("x", dims=("a", "b"), shape=(None, 3))  # First dimension is symbolic
+    y = xtensor("y", dims=("b", "c"), shape=(3, None))  # Second dimension is symbolic
+    z = x.dot(y)
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+
+def test_dot_errors():
+    # Requested contraction dim absent from both inputs
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(3, 4))
+    with pytest.raises(ValueError, match="Dimension e not found in either input"):
+        x.dot(y, dim="e")
+
+    # Concrete dimension size mismatches
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(4, 5))
+    with pytest.raises(
+        ValueError,
+        match="Size of dim 'b' does not match",
+    ):
+        x.dot(y)
+
+    # Symbolic dimension size mismatches
+    x = xtensor("x", dims=("a", "b"), shape=(2, None))
+    y = xtensor("y", dims=("b", "c"), shape=(None, 3))
+    z = x.dot(y)
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
+    # The mismatch isn't caught at graph-build time; it only fails when the
+    # compiled (lowered) function is evaluated
+    with pytest.raises(ValueError, match="not aligned"):
+        fn(x_test, y_test)
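The two failure modes exercised above differ in when they fire. A sketch, assuming the checks encoded in XDot.make_node:

    from pytensor.xtensor.type import xtensor

    # Concrete sizes: the mismatch is caught eagerly, at graph-build time
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(4, 5))
    try:
        x.dot(y)
    except ValueError as e:
        print(e)  # Size of dim 'b' does not match

    # Symbolic sizes: the graph builds fine; the mismatch only surfaces
    # when the compiled function runs on concretely mismatched data
    x = xtensor("x", dims=("a", "b"), shape=(2, None))
    y = xtensor("y", dims=("b", "c"), shape=(None, 3))
    z = x.dot(y)  # no error here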