diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py
index 4fe0ca8106..9e88f99f25 100644
--- a/pytensor/xtensor/math.py
+++ b/pytensor/xtensor/math.py
@@ -4,9 +4,11 @@
 
 import pytensor.scalar as ps
 from pytensor import config
+from pytensor.graph.basic import Apply
 from pytensor.scalar import ScalarOp
-from pytensor.scalar.basic import _cast_mapping
-from pytensor.xtensor.basic import as_xtensor
+from pytensor.scalar.basic import _cast_mapping, upcast
+from pytensor.xtensor.basic import XOp, as_xtensor
+from pytensor.xtensor.type import xtensor
 from pytensor.xtensor.vectorization import XElemwise
 
 
@@ -134,3 +136,117 @@ def cast(x, dtype):
     if dtype not in _xelemwise_cast_op:
         _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
     return _xelemwise_cast_op[dtype](x)
+
+
+class XDot(XOp):
+    """Contract two XTensorVariables over named dimensions.
+
+    The output keeps the non-contracted dimensions of the first input followed by
+    those of the second input, and its dtype is the upcast of both input dtypes.
+
+    Parameters
+    ----------
+    dims : str, tuple of str or None
+        The dimensions to contract over. A single name may be given as a string.
+        If None, every dimension shared by both inputs is contracted.
+    """
+
+    __props__ = ("dims",)
+
+    def __init__(self, dims: str | tuple[str, ...] | None = None):
+        if isinstance(dims, str):
+            # tuple("time") would split the name into single characters
+            dims = (dims,)
+        elif dims is not None:
+            # Always store a tuple so the Op stays hashable through __props__
+            dims = tuple(dims)
+        self.dims = dims
+        super().__init__()
+
+    def make_node(self, x, y):
+        """Validate the contraction and build the Apply node with the output type.
+
+        Raises
+        ------
+        ValueError
+            If a contracted dimension is missing from an input, or if its static
+            lengths in the two inputs are both known and differ.
+        """
+        x = as_xtensor(x)
+        y = as_xtensor(y)
+        x_dims = x.type.dims
+        y_dims = y.type.dims
+
+        if self.dims is None:
+            # Follow the dim order of x: a set intersection has no stable order,
+            # which would make graph construction differ between runs
+            contract_dims = tuple(dim for dim in x_dims if dim in y_dims)
+        else:
+            contract_dims = self.dims
+
+        # Validate here rather than in dot() so the implicit (dims=None) case
+        # and direct uses of the Op are checked as well
+        for dim in contract_dims:
+            if dim not in x_dims:
+                raise ValueError(f"Dimension {dim} not found in first input {x_dims}")
+            if dim not in y_dims:
+                raise ValueError(f"Dimension {dim} not found in second input {y_dims}")
+            x_size = x.type.shape[x_dims.index(dim)]
+            y_size = y.type.shape[y_dims.index(dim)]
+            if x_size is not None and y_size is not None and x_size != y_size:
+                raise ValueError(
+                    f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}"
+                )
+
+        # Surviving dims of x come first, then the surviving dims of y
+        out_dims = []
+        out_shape = []
+        for inp in (x, y):
+            for dim, size in zip(inp.type.dims, inp.type.shape, strict=True):
+                if dim not in contract_dims:
+                    out_dims.append(dim)
+                    out_shape.append(size)
+
+        out = xtensor(
+            dtype=upcast(x.type.dtype, y.type.dtype),
+            shape=tuple(out_shape),
+            dims=tuple(out_dims),
+        )
+        return Apply(self, [x, y], [out])
+
+
+def dot(x, y, dims: str | tuple[str, ...] | None = None):
+    """Contract two XTensorVariables over named dimensions.
+
+    Dimensions are matched by name rather than by position. The result keeps the
+    non-contracted dimensions of `x` followed by those of `y`.
+
+    Parameters
+    ----------
+    x : XTensorVariable
+        First input tensor
+    y : XTensorVariable
+        Second input tensor
+    dims : str or tuple of str, optional
+        The dimensions to contract over. A single name may be given as a string.
+        If None, will contract over all matching dimensions.
+
+    Returns
+    -------
+    XTensorVariable
+        The result of the contraction.
+
+    Raises
+    ------
+    ValueError
+        If a dimension in `dims` is missing from either input, or if a contracted
+        dimension has different known lengths in `x` and `y`.
+
+    Examples
+    --------
+    >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
+    >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
+    >>> z = dot(x, y)  # Result has dimensions ("a", "c")
+    """
+    # XDot normalizes dims in its constructor and validates the inputs in make_node
+    return XDot(dims=dims)(x, y)
diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py
index a65ad0db85..bdbb30f147 100644
--- a/pytensor/xtensor/rewriting/__init__.py
+++ b/pytensor/xtensor/rewriting/__init__.py
@@ -1,5 +1,6 @@
 import pytensor.xtensor.rewriting.basic
 import pytensor.xtensor.rewriting.indexing
+import pytensor.xtensor.rewriting.math
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py
new file mode 100644
index 0000000000..0384276fed
--- /dev/null
+++ b/pytensor/xtensor/rewriting/math.py
@@ -0,0 +1,35 @@
+from pytensor.graph import node_rewriter
+from pytensor.tensor import tensordot
+from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
+from pytensor.xtensor.math import XDot
+from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XDot])
+def lower_dot(fgraph, node):
+    """Lower XDot to a tensordot on the underlying tensors.
+
+    The named dimensions being contracted are translated to positional axes of
+    each input, and the result is wrapped back with the dims of the node output.
+    """
+    [x, y] = node.inputs
+    [out] = node.outputs
+    x_dims = x.type.dims
+    y_dims = y.type.dims
+
+    contract_dims = node.op.dims
+    if contract_dims is None:
+        # Same rule as XDot.make_node: every shared dim, in the order of x
+        contract_dims = tuple(dim for dim in x_dims if dim in y_dims)
+
+    # Positional axes of the contracted dims, paired up between both inputs
+    x_axes = [x_dims.index(dim) for dim in contract_dims]
+    y_axes = [y_dims.index(dim) for dim in contract_dims]
+
+    out_tensor = tensordot(
+        tensor_from_xtensor(x), tensor_from_xtensor(y), axes=(x_axes, y_axes)
+    )
+
+    # tensordot leaves the remaining axes of x followed by those of y
+    return [xtensor_from_tensor(out_tensor, out.type.dims)]
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index fd601df018..e94347d191 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -593,6 +593,10 @@ def stack(self, dim, **dims):
     def unstack(self, dim, **dims):
         return px.shape.unstack(self, dim, **dims)
 
+    def dot(self, other, dims=None):
+        """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
+        return px.math.dot(self, other, dims=dims)
+
 
 class XTensorConstantSignature(tuple):
     def __eq__(self, other):
diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py
index 210bfe9a80..50d9b1cccf 100644
--- a/tests/xtensor/test_math.py
+++ b/tests/xtensor/test_math.py
@@ -151,3 +151,51 @@ def test_cast():
     yc64 = x.astype("complex64")
     with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
         yc64.astype("float64")
+
+
+def test_dot():
+    """Test basic dot product operations."""
+    # Test matrix-matrix dot product
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(3, 4))
+    z = x.dot(y)
+    assert z.type.dims == ("a", "c")
+    assert z.type.shape == (2, 4)
+
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+    # Test matrix-vector dot product
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b",), shape=(3,))
+    z = x.dot(y)
+    assert z.type.dims == ("a",)
+    assert z.type.shape == (2,)
+
+    fn = xr_function([x, y], z)
+    x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
+    y_test = DataArray(np.ones(3), dims=("b",))
+    z_test = fn(x_test, y_test)
+    expected = x_test.dot(y_test)
+    xr_assert_allclose(z_test, expected)
+
+
+def test_dot_dims_as_string():
+    """A dimension passed as a plain string is treated as a single name."""
+    x = xtensor("x", dims=("time", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("time",), shape=(2,))
+    z = x.dot(y, dims="time")
+    assert z.type.dims == ("b",)
+    assert z.type.shape == (3,)
+
+
+def test_dot_incompatible_shapes():
+    """Mismatched static lengths are rejected even when dims are inferred."""
+    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
+    y = xtensor("y", dims=("b", "c"), shape=(4, 5))
+    with pytest.raises(ValueError, match="incompatible shapes"):
+        x.dot(y)