Implement dot for XTensorVariables #1475

Merged
115 changes: 113 additions & 2 deletions pytensor/xtensor/math.py
@@ -1,12 +1,16 @@
import sys
from collections.abc import Iterable
from types import EllipsisType

import numpy as np

import pytensor.scalar as ps
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping
from pytensor.xtensor.basic import as_xtensor
from pytensor.scalar.basic import _cast_mapping, upcast
from pytensor.xtensor.basic import XOp, as_xtensor
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import XElemwise


@@ -134,3 +138,110 @@ def cast(x, dtype):
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)


class XDot(XOp):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
dims : tuple of str
The dimensions to contract over. If None, will contract over all matching dimensions.
"""

__props__ = ("dims",)

def __init__(self, dims: Iterable[str]):
self.dims = tuple(dims)  # tuple so the op hashes consistently via __props__
super().__init__()

def make_node(self, x, y):
x = as_xtensor(x)
y = as_xtensor(y)

x_shape_dict = dict(zip(x.type.dims, x.type.shape))
y_shape_dict = dict(zip(y.type.dims, y.type.shape))

# Check for dimension size mismatches (concrete only)
for dim in self.dims:
x_shape = x_shape_dict.get(dim, None)
y_shape = y_shape_dict.get(dim, None)
if (
isinstance(x_shape, int)
and isinstance(y_shape, int)
and x_shape != y_shape
):
raise ValueError(f"Size of dim '{dim}' does not match")

# Determine output dimensions
shape_dict = {**x_shape_dict, **y_shape_dict}
out_dims = tuple(d for d in shape_dict if d not in self.dims)

# Determine output shape
out_shape = tuple(shape_dict[d] for d in out_dims)

# Determine output dtype
out_dtype = upcast(x.type.dtype, y.type.dtype)

out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, y], [out])
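
A minimal sketch of the resulting type inference (output dims, shape, and dtype), using the `xtensor` constructor imported above; the mixed dtypes are an assumption chosen to show the upcast:

```python
x = xtensor(dtype="float32", shape=(2, 3), dims=("a", "b"))
y = xtensor(dtype="float64", shape=(3, 4), dims=("b", "c"))
out = XDot(dims=("b",))(x, y)
assert out.type.dims == ("a", "c")  # "b" is contracted away
assert out.type.shape == (2, 4)
assert out.type.dtype == "float64"  # upcast of float32 and float64
```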


def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
Member:
I'm gonna merge because I need it for testing, but the xarray dot allows an arbitrary number of operands, which we should try to support eventually. Will open an issue

Contributor Author:

Should be an easy implementation; we can come back to it.
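
A minimal sketch of how multiple operands could be supported later by folding the binary dot pairwise (hypothetical `multi_dot` helper, not part of this PR):

```python
from functools import reduce

def multi_dot(*operands, dim=None):
    # Hypothetical: fold the binary dot over the operands left to right.
    # Caveat: pairwise contraction is not equivalent to a single joint
    # contraction when a contracted dim appears in three or more operands;
    # a real implementation would build one einsum over all operands.
    return reduce(lambda a, b: dot(a, b, dim=dim), operands)
```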


"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
x : XTensorVariable
First input tensor
y : XTensorVariable
Second input tensor
dim : str, Iterable[Hashable], EllipsisType, or None, optional
The dimensions to contract over. If None, will contract over all matching dimensions.
If Ellipsis (...), will contract over all dimensions.

Returns
-------
XTensorVariable
The result of the matrix multiplication.

Examples
--------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
>>> z = dot(x, y) # Result has dimensions ("a", "c")
>>> z = dot(x, y, dim=...) # Contract over all dimensions
"""
x = as_xtensor(x)
y = as_xtensor(y)

x_dims = set(x.type.dims)
y_dims = set(y.type.dims)
intersection = x_dims & y_dims
union = x_dims | y_dims

# Canonicalize dims
if dim is None:
dim_set = intersection
elif dim is ...:
dim_set = union
elif isinstance(dim, str):
dim_set = {dim}
elif isinstance(dim, Iterable):
dim_set = set(dim)
else:
raise TypeError(f"Invalid type for dim: {type(dim)}")

# Validate provided dims: each requested dim must exist in at least one input
for d in dim_set:
if d not in union:
raise ValueError(f"Dimension {d} not found in either input")

return XDot(dims=tuple(dim_set))(x, y)
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -1,5 +1,6 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.math
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
47 changes: 47 additions & 0 deletions pytensor/xtensor/rewriting/math.py
@@ -0,0 +1,47 @@
from string import ascii_lowercase

from pytensor.graph import node_rewriter
from pytensor.tensor import einsum
from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XDot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.

This rewrite converts an XDot operation to a tensor-based dot operation,
handling dimension alignment and contraction.
"""
[x, y] = node.inputs
[out] = node.outputs

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
y_tensor = tensor_from_xtensor(y)

# Collect all dimension names across inputs and output
all_dims = list(
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims)
) # preserve order
if len(all_dims) > len(ascii_lowercase):
raise ValueError("Too many dimensions to map to einsum subscripts")

dim_to_char = dict(zip(all_dims, ascii_lowercase))

# Build einsum string
x_subs = "".join(dim_to_char[d] for d in x.type.dims)
y_subs = "".join(dim_to_char[d] for d in y.type.dims)
out_subs = "".join(dim_to_char[d] for d in out.type.dims)
einsum_str = f"{x_subs},{y_subs}->{out_subs}"

# Perform the einsum operation
out_tensor = einsum(einsum_str, x_tensor, y_tensor)

# Encode the static output shape (specify_shape asserts the shape; it does not reshape)
out_tensor = specify_shape(out_tensor, out.type.shape)

return [xtensor_from_tensor(out_tensor, out.type.dims)]
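
As a concrete illustration of the subscript construction: contracting "b" between x with dims ("a", "b") and y with dims ("b", "c") yields the signature "ab,bc->ac". A standalone reproduction of the mapping built in lower_dot, using only the standard library:

```python
from string import ascii_lowercase

x_dims, y_dims, out_dims = ("a", "b"), ("b", "c"), ("a", "c")
all_dims = list(dict.fromkeys(x_dims + y_dims + out_dims))  # preserve order
dim_to_char = dict(zip(all_dims, ascii_lowercase))
x_subs, y_subs, out_subs = (
    "".join(dim_to_char[d] for d in dims) for dims in (x_dims, y_dims, out_dims)
)
print(f"{x_subs},{y_subs}->{out_subs}")  # ab,bc->ac
```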
4 changes: 4 additions & 0 deletions pytensor/xtensor/type.py
@@ -650,6 +650,10 @@ def stack(self, dim, **dims):
def unstack(self, dim, **dims):
return px.shape.unstack(self, dim, **dims)

def dot(self, other, dim=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dim=dim)
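
A short usage sketch of the new method, mirroring the tests added below:

```python
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)               # contracts the shared dim "b" -> dims ("a", "c")
z_all = x.dot(y, dim=...)  # contracts every dim -> 0-d (scalar) result
```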


class XTensorConstantSignature(tuple):
def __eq__(self, other):
164 changes: 164 additions & 0 deletions tests/xtensor/test_math.py
@@ -151,3 +151,167 @@ def test_cast():
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")


def test_dot():
"""Test basic dot product operations."""
# Test matrix-vector dot product (with multiple-letter dim names)
x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
y = xtensor("y", dims=("bb",), shape=(3,))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
y_test = DataArray(np.ones(3), dims=("bb",))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-vector dot product with ellipsis
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with string dim
z = x.dot(y, dim="b")
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim="b")
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with list of dims
z = x.dot(y, dim=["b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b"])
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with ellipsis
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test a case where there are two dimensions to sum over
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Same but with explicit dimensions
z = x.dot(y, dim=["b", "c"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b", "c"])
xr_assert_allclose(z_test, expected)

# Same but with ellipses
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Dot product with sum
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
expected = x_test.dot(y_test, dim=("a", "b", "c"))

x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
z = x.dot(y, dim=("a", "b", "c"))
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Dot product with sum in the middle
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
expected = x_test.dot(y_test, dim=("b", "d"))
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
z = x.dot(y, dim=("b", "d"))
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with first two dims
expected = x_test.dot(y_test, dim=["a", "b"])
z = x.dot(y, dim=["a", "b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with last two
expected = x_test.dot(y_test, dim=["d", "e"])
z = x.dot(y, dim=["d", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with every other dim
expected = x_test.dot(y_test, dim=["a", "c", "e"])
z = x.dot(y, dim=["a", "c", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Test symbolic shapes
x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic
y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic
z = x.dot(y)
fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)


def test_dot_errors():
# Requested dimension not present in either input
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
with pytest.raises(ValueError, match="Dimension e not found in either input"):
x.dot(y, dim="e")

# Concrete dimension size mismatches
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
with pytest.raises(
ValueError,
match="Size of dim 'b' does not match",
):
x.dot(y)

# Symbolic dimension size mismatches
x = xtensor("x", dims=("a", "b"), shape=(2, None))
y = xtensor("y", dims=("b", "c"), shape=(None, 3))
z = x.dot(y)
fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
# The mismatch is only detected once the graph is rewritten and executed
with pytest.raises(ValueError, match="not aligned"):
fn(x_test, y_test)