From 425e3ccf7784199a920bb3de757fa3adba2ace2b Mon Sep 17 00:00:00 2001
From: Allen Downey <downey@allendowney.com>
Date: Fri, 6 Jun 2025 14:46:06 -0400
Subject: [PATCH] Adding dot for xtensor

---
 pytensor/xtensor/math.py               | 118 ++++++++++++++++++++++++-
 pytensor/xtensor/rewriting/__init__.py |   1 +
 pytensor/xtensor/rewriting/math.py     |  40 +++++++++
 pytensor/xtensor/type.py               |   4 +
 tests/xtensor/test_math.py             |  31 +++++++
 5 files changed, 192 insertions(+), 2 deletions(-)
 create mode 100644 pytensor/xtensor/rewriting/math.py

diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py
index 4fe0ca8106..9e88f99f25 100644
--- a/pytensor/xtensor/math.py
+++ b/pytensor/xtensor/math.py
@@ -4,9 +4,11 @@
 
 import pytensor.scalar as ps
 from pytensor import config
+from pytensor.graph.basic import Apply
 from pytensor.scalar import ScalarOp
-from pytensor.scalar.basic import _cast_mapping
-from pytensor.xtensor.basic import as_xtensor
+from pytensor.scalar.basic import _cast_mapping, upcast
+from pytensor.xtensor.basic import XOp, as_xtensor
+from pytensor.xtensor.type import xtensor
 from pytensor.xtensor.vectorization import XElemwise
 
 
@@ -134,3 +136,115 @@ def cast(x, dtype):
     if dtype not in _xelemwise_cast_op:
         _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
     return _xelemwise_cast_op[dtype](x)
+
+
+class XDot(XOp):
+    """Matrix multiplication between two XTensorVariables.
+
+    This operation performs matrix multiplication between two tensors, automatically
+    aligning and contracting dimensions. The behavior matches xarray's dot operation.
+
+    Parameters
+    ----------
+    dims : tuple of str
+        The dimensions to contract over. If None, will contract over all matching dimensions.
+    """
+
+    __props__ = ("dims",)
+
+    def __init__(self, dims: tuple[str, ...] | None = None):
+        self.dims = dims
+        super().__init__()
+
+    def make_node(self, x, y):
+        x = as_xtensor(x)
+        y = as_xtensor(y)
+
+        # Get dimensions to contract
+        if self.dims is None:
+            # Contract over all matching dimensions
+            x_dims = set(x.type.dims)
+            y_dims = set(y.type.dims)
+            contract_dims = tuple(x_dims & y_dims)
+        else:
+            contract_dims = self.dims
+
+        # Determine output dimensions and shapes
+        x_dims = list(x.type.dims)
+        y_dims = list(y.type.dims)
+        x_shape = list(x.type.shape)
+        y_shape = list(y.type.shape)
+
+        # Remove contracted dimensions
+        for dim in contract_dims:
+            x_idx = x_dims.index(dim)
+            y_idx = y_dims.index(dim)
+            x_dims.pop(x_idx)
+            y_dims.pop(y_idx)
+            x_shape.pop(x_idx)
+            y_shape.pop(y_idx)
+
+        # Combine remaining dimensions
+        out_dims = tuple(x_dims + y_dims)
+        out_shape = tuple(x_shape + y_shape)
+
+        # Determine output dtype
+        out_dtype = upcast(x.type.dtype, y.type.dtype)
+
+        out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
+        return Apply(self, [x, y], [out])
+
+
+def dot(x, y, dims: tuple[str, ...] | None = None):
+    """Matrix multiplication between two XTensorVariables.
+
+    This operation performs matrix multiplication between two tensors, automatically
+    aligning and contracting dimensions. The behavior matches xarray's dot operation.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        First input tensor
+    y : XTensorVariable
+        Second input tensor
+    dims : tuple of str, optional
+        The dimensions to contract over. If None, contracts over all shared dimensions.
+
+    Returns
+    -------
+    XTensorVariable
+        The result of the dot product.
+
+    Examples
+    --------
+    >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
+    >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
+    >>> z = dot(x, y)  # Result has dimensions ("a", "c")
+    """
+    x = as_xtensor(x)
+    y = as_xtensor(y)
+
+    # Validate dimensions if specified
+    if dims is not None:
+        if isinstance(dims, str):
+            # A bare string would otherwise be split into single characters
+            dims = (dims,)
+        elif not isinstance(dims, tuple):
+            dims = tuple(dims)
+        for dim in dims:
+            if dim not in x.type.dims:
+                raise ValueError(
+                    f"Dimension {dim} not found in first input {x.type.dims}"
+                )
+            if dim not in y.type.dims:
+                raise ValueError(
+                    f"Dimension {dim} not found in second input {y.type.dims}"
+                )
+            # Check for compatible shapes in contracted dimensions
+            x_idx = x.type.dims.index(dim)
+            y_idx = y.type.dims.index(dim)
+            x_size = x.type.shape[x_idx]
+            y_size = y.type.shape[y_idx]
+            if x_size is not None and y_size is not None and x_size != y_size:
+                raise ValueError(
+                    f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}"
+                )
+
+    return XDot(dims=dims)(x, y)
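
For illustration, a minimal sketch of how the dimension inference above plays out,
using the `xtensor` constructor and the `dot` helper defined in this file (the
variable names are placeholders, not part of the patch):

    from pytensor.xtensor.math import dot
    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
    y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))

    # dims=None contracts over all shared dims ("b", "c"), as xarray.dot does
    z = dot(x, y)
    assert z.type.dims == ("a", "d")
    assert z.type.shape == (2, 5)

    # An explicit tuple contracts only over the named dims
    xm = xtensor("xm", dims=("a", "b"), shape=(2, 3))
    ym = xtensor("ym", dims=("b", "c"), shape=(3, 4))
    zm = dot(xm, ym, dims=("b",))
    assert zm.type.dims == ("a", "c")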
diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py
index a65ad0db85..bdbb30f147 100644
--- a/pytensor/xtensor/rewriting/__init__.py
+++ b/pytensor/xtensor/rewriting/__init__.py
@@ -1,5 +1,6 @@
 import pytensor.xtensor.rewriting.basic
 import pytensor.xtensor.rewriting.indexing
+import pytensor.xtensor.rewriting.math
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py
new file mode 100644
index 0000000000..0384276fed
--- /dev/null
+++ b/pytensor/xtensor/rewriting/math.py
@@ -0,0 +1,40 @@
+from pytensor.graph import node_rewriter
+from pytensor.tensor import tensordot
+from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
+from pytensor.xtensor.math import XDot
+from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XDot])
+def lower_dot(fgraph, node):
+    """Rewrite XDot to tensor.dot.
+
+    This rewrite converts an XDot operation to a tensor-based dot operation,
+    handling dimension alignment and contraction.
+    """
+    [x, y] = node.inputs
+    [out] = node.outputs
+
+    # Convert inputs to tensors
+    x_tensor = tensor_from_xtensor(x)
+    y_tensor = tensor_from_xtensor(y)
+
+    # Get dimensions to contract
+    if node.op.dims is None:
+        # Contract over all matching dimensions
+        x_dims = set(x.type.dims)
+        y_dims = set(y.type.dims)
+        contract_dims = tuple(x_dims & y_dims)
+    else:
+        contract_dims = node.op.dims
+
+    # Get axes to contract for each input
+    x_axes = [x.type.dims.index(dim) for dim in contract_dims]
+    y_axes = [y.type.dims.index(dim) for dim in contract_dims]
+
+    # Perform dot product
+    out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))
+
+    # Convert back to xtensor
+    return [xtensor_from_tensor(out_tensor, out.type.dims)]
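
As a plain NumPy analogue of the axes mapping performed by this rewrite
(illustration only): the shared dims are looked up positionally in each input
and handed to tensordot.

    import numpy as np

    x = np.ones((2, 3, 4))  # named dims ("a", "b", "c")
    y = np.ones((3, 4, 5))  # named dims ("b", "c", "d")

    # Shared dims "b" and "c" sit at axes (1, 2) of x and (0, 1) of y,
    # which is what lower_dot computes (up to the order of the contracted dims)
    out = np.tensordot(x, y, axes=([1, 2], [0, 1]))
    assert out.shape == (2, 5)  # remaining dims ("a", "d")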
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index fd601df018..e94347d191 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -593,6 +593,10 @@ def stack(self, dim, **dims):
     def unstack(self, dim, **dims):
         return px.shape.unstack(self, dim, **dims)
 
+    def dot(self, other, dims=None):
+        """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
+        return px.math.dot(self, other, dims=dims)
+
 
 class XTensorConstantSignature(tuple):
     def __eq__(self, other):
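
With the method added above, the function form and the method form are
interchangeable; a quick sketch (illustration only):

    from pytensor.xtensor.type import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(3, 4))

    assert x.dot(y).type.dims == ("a", "c")
    assert x.dot(y, dims=("b",)).type.dims == ("a", "c")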
diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py
index 210bfe9a80..50d9b1cccf 100644
--- a/tests/xtensor/test_math.py
+++ b/tests/xtensor/test_math.py
@@ -151,3 +151,34 @@ def test_cast():
     yc64 = x.astype("complex64")
     with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
         yc64.astype("float64")
+
+
+def test_dot():
+    """Test basic dot product operations."""
+    # Test matrix-matrix dot product
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(3, 4))
+    z = x.dot(y)
+    assert z.type.dims == ("a", "c")
+    assert z.type.shape == (2, 4)
+
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-vector dot product
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b",), shape=(3,))
+    z = x.dot(y)
+    assert z.type.dims == ("a",)
+    assert z.type.shape == (2,)
+
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones(3), dims=("b",))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
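
A possible follow-up test for the validation errors raised by `dot` (sketch only,
reusing the helpers already imported in this test module):

    def test_dot_errors():
        x = xtensor("x", dims=("a", "b"), shape=(2, 3))
        y = xtensor("y", dims=("b", "c"), shape=(3, 4))

        # Contracting over a dim missing from an input is rejected
        with pytest.raises(ValueError, match="not found in first input"):
            x.dot(y, dims=("c",))

        # Mismatched static sizes on a contracted dim are rejected
        y_bad = xtensor("y_bad", dims=("b", "c"), shape=(4, 4))
        with pytest.raises(ValueError, match="incompatible shapes"):
            x.dot(y_bad, dims=("b",))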