From 1a4755ec347bf483060c8e87d2c5e9f2f91068a1 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 6 Jun 2025 14:46:06 -0400 Subject: [PATCH 01/12] Adding dot for xtensor --- pytensor/xtensor/math.py | 118 ++++++++++++++++++++++++- pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/math.py | 40 +++++++++ pytensor/xtensor/type.py | 4 + tests/xtensor/test_math.py | 31 +++++++ 5 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 pytensor/xtensor/rewriting/math.py diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 4fe0ca8106..9e88f99f25 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -4,9 +4,11 @@ import pytensor.scalar as ps from pytensor import config +from pytensor.graph.basic import Apply from pytensor.scalar import ScalarOp -from pytensor.scalar.basic import _cast_mapping -from pytensor.xtensor.basic import as_xtensor +from pytensor.scalar.basic import _cast_mapping, upcast +from pytensor.xtensor.basic import XOp, as_xtensor +from pytensor.xtensor.type import xtensor from pytensor.xtensor.vectorization import XElemwise @@ -134,3 +136,115 @@ def cast(x, dtype): if dtype not in _xelemwise_cast_op: _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype]) return _xelemwise_cast_op[dtype](x) + + +class XDot(XOp): + """Matrix multiplication between two XTensorVariables. + + This operation performs matrix multiplication between two tensors, automatically + aligning and contracting dimensions. The behavior matches xarray's dot operation. + + Parameters + ---------- + dims : tuple of str + The dimensions to contract over. If None, will contract over all matching dimensions. + """ + + __props__ = ("dims",) + + def __init__(self, dims: tuple[str, ...] | None = None): + self.dims = dims + super().__init__() + + def make_node(self, x, y): + x = as_xtensor(x) + y = as_xtensor(y) + + # Get dimensions to contract + if self.dims is None: + # Contract over all matching dimensions + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + contract_dims = tuple(x_dims & y_dims) + else: + contract_dims = self.dims + + # Determine output dimensions and shapes + x_dims = list(x.type.dims) + y_dims = list(y.type.dims) + x_shape = list(x.type.shape) + y_shape = list(y.type.shape) + + # Remove contracted dimensions + for dim in contract_dims: + x_idx = x_dims.index(dim) + y_idx = y_dims.index(dim) + x_dims.pop(x_idx) + y_dims.pop(y_idx) + x_shape.pop(x_idx) + y_shape.pop(y_idx) + + # Combine remaining dimensions + out_dims = tuple(x_dims + y_dims) + out_shape = tuple(x_shape + y_shape) + + # Determine output dtype + out_dtype = upcast(x.type.dtype, y.type.dtype) + + out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, y], [out]) + + +def dot(x, y, dims: tuple[str, ...] | None = None): + """Matrix multiplication between two XTensorVariables. + + This operation performs matrix multiplication between two tensors, automatically + aligning and contracting dimensions. The behavior matches xarray's dot operation. + + Parameters + ---------- + x : XTensorVariable + First input tensor + y : XTensorVariable + Second input tensor + dims : tuple of str, optional + The dimensions to contract over. If None, will contract over all matching dimensions. + + Returns + ------- + XTensorVariable + The result of the matrix multiplication. 
+ + Examples + -------- + >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) + >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) + >>> z = dot(x, y) # Result has dimensions ("a", "c") + """ + x = as_xtensor(x) + y = as_xtensor(y) + + # Validate dimensions if specified + if dims is not None: + if not isinstance(dims, tuple): + dims = tuple(dims) + for dim in dims: + if dim not in x.type.dims: + raise ValueError( + f"Dimension {dim} not found in first input {x.type.dims}" + ) + if dim not in y.type.dims: + raise ValueError( + f"Dimension {dim} not found in second input {y.type.dims}" + ) + # Check for compatible shapes in contracted dimensions + x_idx = x.type.dims.index(dim) + y_idx = y.type.dims.index(dim) + x_size = x.type.shape[x_idx] + y_size = y.type.shape[y_idx] + if x_size is not None and y_size is not None and x_size != y_size: + raise ValueError( + f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}" + ) + + return XDot(dims=dims)(x, y) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index a65ad0db85..bdbb30f147 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,5 +1,6 @@ import pytensor.xtensor.rewriting.basic import pytensor.xtensor.rewriting.indexing +import pytensor.xtensor.rewriting.math import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py new file mode 100644 index 0000000000..0384276fed --- /dev/null +++ b/pytensor/xtensor/rewriting/math.py @@ -0,0 +1,40 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import tensordot +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.math import XDot +from pytensor.xtensor.rewriting.utils import register_lower_xtensor + + +@register_lower_xtensor +@node_rewriter(tracks=[XDot]) +def lower_dot(fgraph, node): + """Rewrite XDot to tensor.dot. + + This rewrite converts an XDot operation to a tensor-based dot operation, + handling dimension alignment and contraction. 
+ """ + [x, y] = node.inputs + [out] = node.outputs + + # Convert inputs to tensors + x_tensor = tensor_from_xtensor(x) + y_tensor = tensor_from_xtensor(y) + + # Get dimensions to contract + if node.op.dims is None: + # Contract over all matching dimensions + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + contract_dims = tuple(x_dims & y_dims) + else: + contract_dims = node.op.dims + + # Get axes to contract for each input + x_axes = [x.type.dims.index(dim) for dim in contract_dims] + y_axes = [y.type.dims.index(dim) for dim in contract_dims] + + # Perform dot product + out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) + + # Convert back to xtensor + return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 96b0a1fd7c..385ed3b494 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -650,6 +650,10 @@ def stack(self, dim, **dims): def unstack(self, dim, **dims): return px.shape.unstack(self, dim, **dims) + def dot(self, other, dims=None): + """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" + return px.math.dot(self, other, dims=dims) + class XTensorConstantSignature(tuple): def __eq__(self, other): diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 210bfe9a80..50d9b1cccf 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -151,3 +151,34 @@ def test_cast(): yc64 = x.astype("complex64") with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): yc64.astype("float64") + + +def test_dot(): + """Test basic dot product operations.""" + # Test matrix-matrix dot product + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + z = x.dot(y) + assert z.type.dims == ("a", "c") + assert z.type.shape == (2, 4) + + fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((3, 4)), dims=("b", "c")) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + + # Test matrix-vector dot product + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b",), shape=(3,)) + z = x.dot(y) + assert z.type.dims == ("a",) + assert z.type.shape == (2,) + + fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones(3), dims=("b",)) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) From 28796dcb70a457f935d269af9f179b8f57f60b91 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sat, 14 Jun 2025 15:15:24 -0400 Subject: [PATCH 02/12] First pass at xtensor.dot --- pytensor/xtensor/math.py | 96 ++++++++++++++---------------- pytensor/xtensor/rewriting/math.py | 18 +++--- tests/xtensor/test_math.py | 81 +++++++++++++++++++++---- 3 files changed, 122 insertions(+), 73 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 9e88f99f25..093abf2101 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -1,4 +1,6 @@ import sys +from types import EllipsisType +from typing import Hashable, Iterable import numpy as np @@ -148,45 +150,36 @@ class XDot(XOp): ---------- dims : tuple of str The dimensions to contract over. If None, will contract over all matching dimensions. + sum_result : bool + If True, sum over all remaining axes after contraction (for full contraction, e.g. dims=...). 
""" - __props__ = ("dims",) + __props__ = ("dims", "sum_result") - def __init__(self, dims: tuple[str, ...] | None = None): + def __init__(self, dims: Iterable[str], sum_result: bool = False): self.dims = dims + self.sum_result = sum_result super().__init__() def make_node(self, x, y): x = as_xtensor(x) y = as_xtensor(y) - # Get dimensions to contract - if self.dims is None: - # Contract over all matching dimensions - x_dims = set(x.type.dims) - y_dims = set(y.type.dims) - contract_dims = tuple(x_dims & y_dims) - else: - contract_dims = self.dims - - # Determine output dimensions and shapes - x_dims = list(x.type.dims) - y_dims = list(y.type.dims) - x_shape = list(x.type.shape) - y_shape = list(y.type.shape) - - # Remove contracted dimensions - for dim in contract_dims: - x_idx = x_dims.index(dim) - y_idx = y_dims.index(dim) - x_dims.pop(x_idx) - y_dims.pop(y_idx) - x_shape.pop(x_idx) - y_shape.pop(y_idx) + # Filter out contracted dimensions + x_dims = [dim for dim in x.type.dims if dim not in self.dims] + y_dims = [dim for dim in y.type.dims if dim not in self.dims] + x_shape = [size for dim, size in zip(x.type.dims, x.type.shape) + if dim not in self.dims] + y_shape = [size for dim, size in zip(y.type.dims, y.type.shape) + if dim not in self.dims] # Combine remaining dimensions - out_dims = tuple(x_dims + y_dims) - out_shape = tuple(x_shape + y_shape) + if self.sum_result: + out_dims = () + out_shape = () + else: + out_dims = tuple(x_dims + y_dims) + out_shape = tuple(x_shape + y_shape) # Determine output dtype out_dtype = upcast(x.type.dtype, y.type.dtype) @@ -195,7 +188,7 @@ def make_node(self, x, y): return Apply(self, [x, y], [out]) -def dot(x, y, dims: tuple[str, ...] | None = None): +def dot(x, y, dims: str | Iterable[Hashable] | EllipsisType | None = None): """Matrix multiplication between two XTensorVariables. This operation performs matrix multiplication between two tensors, automatically @@ -207,8 +200,9 @@ def dot(x, y, dims: tuple[str, ...] | None = None): First input tensor y : XTensorVariable Second input tensor - dims : tuple of str, optional + dims : str, Iterable[Hashable], EllipsisType, or None, optional The dimensions to contract over. If None, will contract over all matching dimensions. + If Ellipsis (...), will contract over all dimensions. Returns ------- @@ -220,31 +214,33 @@ def dot(x, y, dims: tuple[str, ...] | None = None): >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) >>> z = dot(x, y) # Result has dimensions ("a", "c") + >>> z = dot(x, y, dim=...) 
# Contract over all dimensions """ x = as_xtensor(x) y = as_xtensor(y) - # Validate dimensions if specified - if dims is not None: - if not isinstance(dims, tuple): - dims = tuple(dims) + # Canonicalize dims + if isinstance(dims, str): + dims = (dims,) + elif isinstance(dims, Iterable): + dims = tuple(dims) + + # Validate provided dims + if isinstance(dims, Iterable): for dim in dims: if dim not in x.type.dims: - raise ValueError( - f"Dimension {dim} not found in first input {x.type.dims}" - ) + raise ValueError(f"Dimension {dim} not found in first input {x.type.dims}") if dim not in y.type.dims: - raise ValueError( - f"Dimension {dim} not found in second input {y.type.dims}" - ) - # Check for compatible shapes in contracted dimensions - x_idx = x.type.dims.index(dim) - y_idx = y.type.dims.index(dim) - x_size = x.type.shape[x_idx] - y_size = y.type.shape[y_idx] - if x_size is not None and y_size is not None and x_size != y_size: - raise ValueError( - f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}" - ) - - return XDot(dims=dims)(x, y) + raise ValueError(f"Dimension {dim} not found in second input {y.type.dims}") + + # If dims is ... , we have to sum over all remaining axes + sum_result = dims is ... + + # Handle None and ... cases + if dims is None or dims is ...: + # Contract over all matching dimensions + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + dims = tuple(x_dims & y_dims) + + return XDot(dims=dims, sum_result=sum_result)(x, y) diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index 0384276fed..51e940442e 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -20,21 +20,17 @@ def lower_dot(fgraph, node): x_tensor = tensor_from_xtensor(x) y_tensor = tensor_from_xtensor(y) - # Get dimensions to contract - if node.op.dims is None: - # Contract over all matching dimensions - x_dims = set(x.type.dims) - y_dims = set(y.type.dims) - contract_dims = tuple(x_dims & y_dims) - else: - contract_dims = node.op.dims - # Get axes to contract for each input - x_axes = [x.type.dims.index(dim) for dim in contract_dims] - y_axes = [y.type.dims.index(dim) for dim in contract_dims] + x_axes = [x.type.dims.index(dim) for dim in node.op.dims] + y_axes = [y.type.dims.index(dim) for dim in node.op.dims] # Perform dot product out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) + # Sum over all remaining axes if needed + if node.op.sum_result: + # Sum over all remaining dimensions + out_tensor = out_tensor.sum(axis=None) + # Convert back to xtensor return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 50d9b1cccf..9d8d7f4112 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -155,30 +155,87 @@ def test_cast(): def test_dot(): """Test basic dot product operations.""" - # Test matrix-matrix dot product + # Test matrix-vector dot product x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + y = xtensor("y", dims=("b",), shape=(3,)) z = x.dot(y) - assert z.type.dims == ("a", "c") - assert z.type.shape == (2, 4) - fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) - y_test = DataArray(np.ones((3, 4)), dims=("b", "c")) + y_test = DataArray(np.ones(3), dims=("b",)) z_test = fn(x_test, y_test) expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) - # Test matrix-vector dot product + # Test matrix-matrix dot product 
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = xtensor("y", dims=("b",), shape=(3,)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) z = x.dot(y) - assert z.type.dims == ("a",) - assert z.type.shape == (2,) + fn = xr_function([x, y], z) + # Use outer product to create test data with diverse values + x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), + dims=("a", "b")) + y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), + dims=("b", "c")) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with string dims + z = x.dot(y, dims="b") fn = xr_function([x, y], z) - x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) - y_test = DataArray(np.ones(3), dims=("b",)) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim="b") + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with list of dims + z = x.dot(y, dims=["b"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=["b"]) + xr_assert_allclose(z_test, expected) + + # Test matrix-matrix dot product with ellipsis + if True: + z = x.dot(y, dims=...) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) + xr_assert_allclose(z_test, expected) + + # Test a case where there are two dimensions to contract over + x = xtensor("x", dims=("a", "b", 'c'), shape=(2, 3, 4)) + y = xtensor("y", dims=("b", "c", 'd'), shape=(3, 4, 5)) + z = x.dot(y) + fn = xr_function([x, y], z) + + x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c")) + y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d")) z_test = fn(x_test, y_test) expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) + + # Same but with explicit dimensions + z = x.dot(y, dims=["b", "c"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=["b", "c"]) + xr_assert_allclose(z_test, expected) + + # Same but with ellipses + if True: + z = x.dot(y, dims=...) + fn = xr_function([x, y], z) + + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) 
+ xr_assert_allclose(z_test, expected) + + + + + + + + \ No newline at end of file From 7c9214dba316f8712e0475393f55548f66ddb650 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sat, 14 Jun 2025 15:18:54 -0400 Subject: [PATCH 03/12] Lint --- pytensor/xtensor/math.py | 24 +++++++++++++++--------- tests/xtensor/test_math.py | 22 ++++++---------------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 093abf2101..2f8d946254 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -1,6 +1,6 @@ import sys +from collections.abc import Hashable, Iterable from types import EllipsisType -from typing import Hashable, Iterable import numpy as np @@ -168,10 +168,12 @@ def make_node(self, x, y): # Filter out contracted dimensions x_dims = [dim for dim in x.type.dims if dim not in self.dims] y_dims = [dim for dim in y.type.dims if dim not in self.dims] - x_shape = [size for dim, size in zip(x.type.dims, x.type.shape) - if dim not in self.dims] - y_shape = [size for dim, size in zip(y.type.dims, y.type.shape) - if dim not in self.dims] + x_shape = [ + size for dim, size in zip(x.type.dims, x.type.shape) if dim not in self.dims + ] + y_shape = [ + size for dim, size in zip(y.type.dims, y.type.shape) if dim not in self.dims + ] # Combine remaining dimensions if self.sum_result: @@ -229,18 +231,22 @@ def dot(x, y, dims: str | Iterable[Hashable] | EllipsisType | None = None): if isinstance(dims, Iterable): for dim in dims: if dim not in x.type.dims: - raise ValueError(f"Dimension {dim} not found in first input {x.type.dims}") + raise ValueError( + f"Dimension {dim} not found in first input {x.type.dims}" + ) if dim not in y.type.dims: - raise ValueError(f"Dimension {dim} not found in second input {y.type.dims}") + raise ValueError( + f"Dimension {dim} not found in second input {y.type.dims}" + ) # If dims is ... , we have to sum over all remaining axes sum_result = dims is ... - + # Handle None and ... cases if dims is None or dims is ...: # Contract over all matching dimensions x_dims = set(x.type.dims) y_dims = set(y.type.dims) dims = tuple(x_dims & y_dims) - + return XDot(dims=dims, sum_result=sum_result)(x, y) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 9d8d7f4112..c5a9822136 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -174,10 +174,8 @@ def test_dot(): fn = xr_function([x, y], z) # Use outer product to create test data with diverse values - x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), - dims=("a", "b")) - y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), - dims=("b", "c")) + x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b")) + y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c")) z_test = fn(x_test, y_test) expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) @@ -188,14 +186,14 @@ def test_dot(): z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim="b") xr_assert_allclose(z_test, expected) - + # Test matrix-matrix dot product with list of dims z = x.dot(y, dims=["b"]) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=["b"]) xr_assert_allclose(z_test, expected) - + # Test matrix-matrix dot product with ellipsis if True: z = x.dot(y, dims=...) 
@@ -205,8 +203,8 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Test a case where there are two dimensions to contract over - x = xtensor("x", dims=("a", "b", 'c'), shape=(2, 3, 4)) - y = xtensor("y", dims=("b", "c", 'd'), shape=(3, 4, 5)) + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5)) z = x.dot(y) fn = xr_function([x, y], z) @@ -231,11 +229,3 @@ def test_dot(): z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) xr_assert_allclose(z_test, expected) - - - - - - - - \ No newline at end of file From 18501bfe9717015c3a15cd91ce3237906f9b0212 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sat, 14 Jun 2025 20:07:38 -0400 Subject: [PATCH 04/12] Adding shape checking at rewrite time --- pytensor/xtensor/math.py | 4 ++-- pytensor/xtensor/rewriting/math.py | 14 ++++++++++++-- tests/xtensor/test_math.py | 26 ++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 2f8d946254..7937722fab 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -1,5 +1,5 @@ import sys -from collections.abc import Hashable, Iterable +from collections.abc import Iterable from types import EllipsisType import numpy as np @@ -190,7 +190,7 @@ def make_node(self, x, y): return Apply(self, [x, y], [out]) -def dot(x, y, dims: str | Iterable[Hashable] | EllipsisType | None = None): +def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None): """Matrix multiplication between two XTensorVariables. This operation performs matrix multiplication between two tensors, automatically diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index 51e940442e..09fb899e20 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -20,11 +20,21 @@ def lower_dot(fgraph, node): x_tensor = tensor_from_xtensor(x) y_tensor = tensor_from_xtensor(y) - # Get axes to contract for each input + # Get the axes for contraction x_axes = [x.type.dims.index(dim) for dim in node.op.dims] y_axes = [y.type.dims.index(dim) for dim in node.op.dims] - # Perform dot product + # Check that shapes match along contracted dimensions + for dim in node.op.dims: + x_idx = x.type.dims.index(dim) + y_idx = y.type.dims.index(dim) + if x.type.shape[x_idx] != y.type.shape[y_idx]: + raise ValueError( + "Input arrays have inconsistent type shape along the axes " + f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}" + ) + + # Perform the tensordot operation out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) # Sum over all remaining axes if needed diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index c5a9822136..cb6c631b07 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -229,3 +229,29 @@ def test_dot(): z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) 
xr_assert_allclose(z_test, expected) + + +def test_dot_errors(): + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + with pytest.raises(ValueError, match="Dimension c not found in first input"): + x.dot(y, dims=["c"]) + with pytest.raises(ValueError, match="Dimension a not found in second input"): + x.dot(y, dims=["a"]) + + # Test a case where there are no matching dimensions + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) + with pytest.raises(ValueError, match="cannot reindex or align along dimension"): + x_test.dot(y_test) + + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(4, 5)) + with pytest.raises( + ValueError, match="Input arrays have inconsistent type shape along the axes" + ): + z = x.dot(y) + fn = function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) + fn(x_test, y_test) From 23b17992a81cd5e8c986c599624c10d357a92348 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sat, 14 Jun 2025 20:33:57 -0400 Subject: [PATCH 05/12] Cleanup --- tests/xtensor/test_math.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index cb6c631b07..e1c5a38fca 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -167,13 +167,19 @@ def test_dot(): expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) + # Test matrix-vector dot product with ellipsis + z = x.dot(y, dims=...) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) + xr_assert_allclose(z_test, expected) + # Test matrix-matrix dot product x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = xtensor("y", dims=("b", "c"), shape=(3, 4)) z = x.dot(y) fn = xr_function([x, y], z) - # Use outer product to create test data with diverse values x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b")) y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c")) z_test = fn(x_test, y_test) @@ -195,14 +201,13 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Test matrix-matrix dot product with ellipsis - if True: - z = x.dot(y, dims=...) - fn = xr_function([x, y], z) - z_test = fn(x_test, y_test) - expected = x_test.dot(y_test, dim=...) - xr_assert_allclose(z_test, expected) - - # Test a case where there are two dimensions to contract over + z = x.dot(y, dims=...) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) + xr_assert_allclose(z_test, expected) + + # Test a case where there are two dimensions to sum over x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5)) z = x.dot(y) @@ -222,13 +227,12 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Same but with ellipses - if True: - z = x.dot(y, dims=...) - fn = xr_function([x, y], z) + z = x.dot(y, dims=...) + fn = xr_function([x, y], z) - z_test = fn(x_test, y_test) - expected = x_test.dot(y_test, dim=...) - xr_assert_allclose(z_test, expected) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test, dim=...) 
+ xr_assert_allclose(z_test, expected) def test_dot_errors(): From 497c974ef2ff7ad1a2c14602e3d2261752604eaa Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sun, 15 Jun 2025 09:32:15 -0400 Subject: [PATCH 06/12] Compose XDot and Sum --- pytensor/xtensor/math.py | 25 +++++++++++++------------ pytensor/xtensor/rewriting/math.py | 5 ----- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 7937722fab..fe60612379 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -150,15 +150,12 @@ class XDot(XOp): ---------- dims : tuple of str The dimensions to contract over. If None, will contract over all matching dimensions. - sum_result : bool - If True, sum over all remaining axes after contraction (for full contraction, e.g. dims=...). """ - __props__ = ("dims", "sum_result") + __props__ = ("dims",) - def __init__(self, dims: Iterable[str], sum_result: bool = False): + def __init__(self, dims: Iterable[str]): self.dims = dims - self.sum_result = sum_result super().__init__() def make_node(self, x, y): @@ -176,12 +173,8 @@ def make_node(self, x, y): ] # Combine remaining dimensions - if self.sum_result: - out_dims = () - out_shape = () - else: - out_dims = tuple(x_dims + y_dims) - out_shape = tuple(x_shape + y_shape) + out_dims = tuple(x_dims + y_dims) + out_shape = tuple(x_shape + y_shape) # Determine output dtype out_dtype = upcast(x.type.dtype, y.type.dtype) @@ -249,4 +242,12 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None): y_dims = set(y.type.dims) dims = tuple(x_dims & y_dims) - return XDot(dims=dims, sum_result=sum_result)(x, y) + result = XDot(dims=dims)(x, y) + + if sum_result: + from pytensor.xtensor.reduction import sum as xtensor_sum + + # Sum over all remaining axes + result = xtensor_sum(result, dim=...) 
+ + return result diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index 09fb899e20..2a94ef6130 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -37,10 +37,5 @@ def lower_dot(fgraph, node): # Perform the tensordot operation out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) - # Sum over all remaining axes if needed - if node.op.sum_result: - # Sum over all remaining dimensions - out_tensor = out_tensor.sum(axis=None) - # Convert back to xtensor return [xtensor_from_tensor(out_tensor, out.type.dims)] From e81657ed2734d4a3ec805615921f441309477dee Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sun, 15 Jun 2025 16:17:30 -0400 Subject: [PATCH 07/12] Generalizing XDot --- pytensor/xtensor/math.py | 80 +++++++++++++++++--------------------- pytensor/xtensor/type.py | 4 +- tests/xtensor/test_math.py | 43 ++++++++++++++------ 3 files changed, 69 insertions(+), 58 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index fe60612379..fef00d9148 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -162,19 +162,15 @@ def make_node(self, x, y): x = as_xtensor(x) y = as_xtensor(y) - # Filter out contracted dimensions - x_dims = [dim for dim in x.type.dims if dim not in self.dims] - y_dims = [dim for dim in y.type.dims if dim not in self.dims] - x_shape = [ - size for dim, size in zip(x.type.dims, x.type.shape) if dim not in self.dims - ] - y_shape = [ - size for dim, size in zip(y.type.dims, y.type.shape) if dim not in self.dims - ] - - # Combine remaining dimensions - out_dims = tuple(x_dims + y_dims) - out_shape = tuple(x_shape + y_shape) + x_shape_dict = dict(zip(x.type.dims, x.type.shape)) + y_shape_dict = dict(zip(y.type.dims, y.type.shape)) + shape_dict = {**x_shape_dict, **y_shape_dict} + + # Determine output dimensions + out_dims = tuple(d for d in shape_dict if d not in self.dims) + + # Determine output shape + out_shape = tuple(shape_dict[d] for d in out_dims) # Determine output dtype out_dtype = upcast(x.type.dtype, y.type.dtype) @@ -183,7 +179,7 @@ def make_node(self, x, y): return Apply(self, [x, y], [out]) -def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None): +def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): """Matrix multiplication between two XTensorVariables. This operation performs matrix multiplication between two tensors, automatically @@ -195,7 +191,7 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None): First input tensor y : XTensorVariable Second input tensor - dims : str, Iterable[Hashable], EllipsisType, or None, optional + dim : str, Iterable[Hashable], EllipsisType, or None, optional The dimensions to contract over. If None, will contract over all matching dimensions. If Ellipsis (...), will contract over all dimensions. 
@@ -214,40 +210,34 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None): x = as_xtensor(x) y = as_xtensor(y) + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + intersection = x_dims & y_dims + union = x_dims | y_dims + # Canonicalize dims - if isinstance(dims, str): - dims = (dims,) - elif isinstance(dims, Iterable): - dims = tuple(dims) + if dim is None: + dim_set = intersection + elif dim is ...: + dim_set = union + elif isinstance(dim, str): + dim_set = {dim} + elif isinstance(dim, Iterable): + dim_set = set(dim) # Validate provided dims - if isinstance(dims, Iterable): - for dim in dims: - if dim not in x.type.dims: - raise ValueError( - f"Dimension {dim} not found in first input {x.type.dims}" - ) - if dim not in y.type.dims: - raise ValueError( - f"Dimension {dim} not found in second input {y.type.dims}" - ) - - # If dims is ... , we have to sum over all remaining axes - sum_result = dims is ... - - # Handle None and ... cases - if dims is None or dims is ...: - # Contract over all matching dimensions - x_dims = set(x.type.dims) - y_dims = set(y.type.dims) - dims = tuple(x_dims & y_dims) - - result = XDot(dims=dims)(x, y) - - if sum_result: - from pytensor.xtensor.reduction import sum as xtensor_sum + # Check if any dimension is not found in either input + for d in dim_set: + if d not in union: + raise ValueError(f"Dimension {d} not found in either input {y.type.dims}") + + dotted_dims = tuple(dim_set & intersection) + summed_dims = tuple(dim_set.difference(dotted_dims)) + + result = XDot(dims=dotted_dims)(x, y) + if summed_dims: # Sum over all remaining axes - result = xtensor_sum(result, dim=...) + result = result.sum(dim=summed_dims) return result diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 385ed3b494..f16a7432bf 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -650,9 +650,9 @@ def stack(self, dim, **dims): def unstack(self, dim, **dims): return px.shape.unstack(self, dim, **dims) - def dot(self, other, dims=None): + def dot(self, other, dim=None): """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" - return px.math.dot(self, other, dims=dims) + return px.math.dot(self, other, dim=dim) class XTensorConstantSignature(tuple): diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index e1c5a38fca..5c72097dd1 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -168,7 +168,7 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Test matrix-vector dot product with ellipsis - z = x.dot(y, dims=...) + z = x.dot(y, dim=...) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) @@ -186,22 +186,22 @@ def test_dot(): expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) - # Test matrix-matrix dot product with string dims - z = x.dot(y, dims="b") + # Test matrix-matrix dot product with string dim + z = x.dot(y, dim="b") fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim="b") xr_assert_allclose(z_test, expected) # Test matrix-matrix dot product with list of dims - z = x.dot(y, dims=["b"]) + z = x.dot(y, dim=["b"]) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=["b"]) xr_assert_allclose(z_test, expected) # Test matrix-matrix dot product with ellipsis - z = x.dot(y, dims=...) + z = x.dot(y, dim=...) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) 
@@ -220,28 +220,49 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Same but with explicit dimensions - z = x.dot(y, dims=["b", "c"]) + z = x.dot(y, dim=["b", "c"]) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=["b", "c"]) xr_assert_allclose(z_test, expected) # Same but with ellipses - z = x.dot(y, dims=...) + z = x.dot(y, dim=...) fn = xr_function([x, y], z) z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) xr_assert_allclose(z_test, expected) + # Dot product with sum + x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c")) + y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d")) + expected = x_test.dot(y_test, dim=("a", "b", "c")) + + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5)) + z = x.dot(y, dim=("a", "b", "c")) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + xr_assert_allclose(z_test, expected) + + return + # Dot product with sum in the middle + # This is not supported yet + x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d")) + y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e")) + expected = x_test.dot(y_test, dim=("b", "d")) + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5)) + y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6)) + z = x.dot(y, dim=("b", "d")) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + xr_assert_allclose(z_test, expected) + def test_dot_errors(): x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = xtensor("y", dims=("b", "c"), shape=(3, 4)) - with pytest.raises(ValueError, match="Dimension c not found in first input"): - x.dot(y, dims=["c"]) - with pytest.raises(ValueError, match="Dimension a not found in second input"): - x.dot(y, dims=["a"]) # Test a case where there are no matching dimensions x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) From e44ca9fefa3a7470e64962e443f83c55be435444 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sun, 15 Jun 2025 19:35:30 -0400 Subject: [PATCH 08/12] Now with more einsum --- pytensor/xtensor/math.py | 9 +------ pytensor/xtensor/rewriting/math.py | 42 ++++++++++++++++++------------ tests/xtensor/test_math.py | 29 ++++++++++++++++++--- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index fef00d9148..ec19c52c01 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -231,13 +231,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): if d not in union: raise ValueError(f"Dimension {d} not found in either input {y.type.dims}") - dotted_dims = tuple(dim_set & intersection) - summed_dims = tuple(dim_set.difference(dotted_dims)) - - result = XDot(dims=dotted_dims)(x, y) - - if summed_dims: - # Sum over all remaining axes - result = result.sum(dim=summed_dims) + result = XDot(dims=tuple(dim_set))(x, y) return result diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index 2a94ef6130..3fa07aeab3 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -1,5 +1,8 @@ +from string import ascii_lowercase + from pytensor.graph import node_rewriter -from pytensor.tensor import tensordot +from pytensor.tensor import einsum +from pytensor.tensor.shape import reshape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.math import XDot 
from pytensor.xtensor.rewriting.utils import register_lower_xtensor @@ -20,22 +23,29 @@ def lower_dot(fgraph, node): x_tensor = tensor_from_xtensor(x) y_tensor = tensor_from_xtensor(y) - # Get the axes for contraction - x_axes = [x.type.dims.index(dim) for dim in node.op.dims] - y_axes = [y.type.dims.index(dim) for dim in node.op.dims] + # Collect all dimension names across inputs and output + all_dims = list( + dict.fromkeys(x.type.dims + y.type.dims + out.type.dims) + ) # preserve order + if len(all_dims) > len(ascii_lowercase): + raise ValueError("Too many dimensions to map to einsum subscripts") + + dim_to_char = dict(zip(all_dims, ascii_lowercase)) + + # Build einsum string + x_subs = "".join(dim_to_char[d] for d in x.type.dims) + y_subs = "".join(dim_to_char[d] for d in y.type.dims) + out_subs = "".join(dim_to_char[d] for d in out.type.dims) + einsum_str = f"{x_subs},{y_subs}->{out_subs}" - # Check that shapes match along contracted dimensions - for dim in node.op.dims: - x_idx = x.type.dims.index(dim) - y_idx = y.type.dims.index(dim) - if x.type.shape[x_idx] != y.type.shape[y_idx]: - raise ValueError( - "Input arrays have inconsistent type shape along the axes " - f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}" - ) + # Perform the einsum operation + out_tensor = einsum(einsum_str, x_tensor, y_tensor) - # Perform the tensordot operation - out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) + # Reshape to match the expected output shape + try: + out_tensor = reshape(out_tensor, out.type.shape) + except (TypeError, ValueError): + # Skip reshaping if symbolic shapes are present + pass - # Convert back to xtensor return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 5c72097dd1..4ec5b7e76b 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -246,11 +246,10 @@ def test_dot(): z_test = fn(x_test, y_test) xr_assert_allclose(z_test, expected) - return # Dot product with sum in the middle # This is not supported yet - x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d")) - y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e")) + x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d")) + y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e")) expected = x_test.dot(y_test, dim=("b", "d")) x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5)) y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6)) @@ -259,6 +258,27 @@ def test_dot(): z_test = fn(x_test, y_test) xr_assert_allclose(z_test, expected) + # Same but with first two dims + expected = x_test.dot(y_test, dim=["a", "b"]) + z = x.dot(y, dim=["a", "b"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + xr_assert_allclose(z_test, expected) + + # Same but with last two + expected = x_test.dot(y_test, dim=["d", "e"]) + z = x.dot(y, dim=["d", "e"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + xr_assert_allclose(z_test, expected) + + # Same but with every other dim + expected = x_test.dot(y_test, dim=["a", "c", "e"]) + z = x.dot(y, dim=["a", "c", "e"]) + fn = xr_function([x, y], z) + z_test = fn(x_test, y_test) + xr_assert_allclose(z_test, expected) + def test_dot_errors(): x = xtensor("x", dims=("a", "b"), shape=(2, 3)) @@ -273,7 +293,8 @@ def test_dot_errors(): x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = xtensor("y", dims=("b", 
"c"), shape=(4, 5)) with pytest.raises( - ValueError, match="Input arrays have inconsistent type shape along the axes" + ValueError, + match="Size of label 'b' for operand 1.*does not match previous terms", ): z = x.dot(y) fn = function([x, y], z) From 6ce20d0da8193251c552cf8eee0f31dcca1ab749 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sun, 15 Jun 2025 19:41:24 -0400 Subject: [PATCH 09/12] Cleanup --- tests/xtensor/test_math.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 4ec5b7e76b..d43ea1c694 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -155,14 +155,14 @@ def test_cast(): def test_dot(): """Test basic dot product operations.""" - # Test matrix-vector dot product - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = xtensor("y", dims=("b",), shape=(3,)) + # Test matrix-vector dot product (with multiple-letter dim names) + x = xtensor("x", dims=("aa", "bb"), shape=(2, 3)) + y = xtensor("y", dims=("bb",), shape=(3,)) z = x.dot(y) fn = xr_function([x, y], z) - x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) - y_test = DataArray(np.ones(3), dims=("b",)) + x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb")) + y_test = DataArray(np.ones(3), dims=("bb",)) z_test = fn(x_test, y_test) expected = x_test.dot(y_test) xr_assert_allclose(z_test, expected) @@ -229,7 +229,6 @@ def test_dot(): # Same but with ellipses z = x.dot(y, dim=...) fn = xr_function([x, y], z) - z_test = fn(x_test, y_test) expected = x_test.dot(y_test, dim=...) xr_assert_allclose(z_test, expected) @@ -247,7 +246,6 @@ def test_dot(): xr_assert_allclose(z_test, expected) # Dot product with sum in the middle - # This is not supported yet x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d")) y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e")) expected = x_test.dot(y_test, dim=("b", "d")) From ed6bbd65bf89a8d09a81e9f81a2a5bbb2a77bd85 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 16 Jun 2025 10:21:30 -0400 Subject: [PATCH 10/12] Handle symbolic shapes --- pytensor/xtensor/rewriting/math.py | 10 +++++----- tests/xtensor/test_math.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index 3fa07aeab3..dcf50eb700 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -41,11 +41,11 @@ def lower_dot(fgraph, node): # Perform the einsum operation out_tensor = einsum(einsum_str, x_tensor, y_tensor) - # Reshape to match the expected output shape - try: + # Check if we have symbolic shapes + sym_shape = any(not isinstance(s, int) for s in out.type.shape) + + # If we have concrete shapes, reshape to match them + if not sym_shape: out_tensor = reshape(out_tensor, out.type.shape) - except (TypeError, ValueError): - # Skip reshaping if symbolic shapes are present - pass return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index d43ea1c694..d7718a5c69 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -277,6 +277,17 @@ def test_dot(): z_test = fn(x_test, y_test) xr_assert_allclose(z_test, expected) + # Test symbolic shapes + x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic + y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic + z = x.dot(y) + fn = 
xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((3, 4)), dims=("b", "c")) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + def test_dot_errors(): x = xtensor("x", dims=("a", "b"), shape=(2, 3)) From 9f331174ddc1907ed40fbda0e335ff3ac87e9faa Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 16 Jun 2025 17:31:57 -0400 Subject: [PATCH 11/12] Use specify_shape (and better error tests) --- pytensor/xtensor/math.py | 13 ++++++++++++- pytensor/xtensor/rewriting/math.py | 10 +++------- tests/xtensor/test_math.py | 27 ++++++++++++++++----------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index ec19c52c01..566bb3def6 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -229,7 +229,18 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): # Check if any dimension is not found in either input for d in dim_set: if d not in union: - raise ValueError(f"Dimension {d} not found in either input {y.type.dims}") + raise ValueError(f"Dimension {d} not found in either input") + + # Check for dimension size mismatches (concrete only) + for dim in intersection: + x_idx = x.type.dims.index(dim) + y_idx = y.type.dims.index(dim) + if ( + isinstance(x.type.shape[x_idx], int) + and isinstance(y.type.shape[y_idx], int) + and x.type.shape[x_idx] != y.type.shape[y_idx] + ): + raise ValueError(f"Size of dim '{dim}' does not match") result = XDot(dims=tuple(dim_set))(x, y) diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index dcf50eb700..850d91fad3 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter from pytensor.tensor import einsum -from pytensor.tensor.shape import reshape +from pytensor.tensor.shape import specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.math import XDot from pytensor.xtensor.rewriting.utils import register_lower_xtensor @@ -41,11 +41,7 @@ def lower_dot(fgraph, node): # Perform the einsum operation out_tensor = einsum(einsum_str, x_tensor, y_tensor) - # Check if we have symbolic shapes - sym_shape = any(not isinstance(s, int) for s in out.type.shape) - - # If we have concrete shapes, reshape to match them - if not sym_shape: - out_tensor = reshape(out_tensor, out.type.shape) + # Reshape to match the output shape + out_tensor = specify_shape(out_tensor, out.type.shape) return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index d7718a5c69..088948de85 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -290,23 +290,28 @@ def test_dot(): def test_dot_errors(): + # No matching dimensions x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + with pytest.raises(ValueError, match="Dimension e not found in either input"): + x.dot(y, dim="e") - # Test a case where there are no matching dimensions - x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) - y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) - with pytest.raises(ValueError, match="cannot reindex or align along dimension"): - x_test.dot(y_test) - + # Concrete dimension size mismatches x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = xtensor("y", dims=("b", "c"), shape=(4, 5)) with pytest.raises( 
ValueError, - match="Size of label 'b' for operand 1.*does not match previous terms", + match="Size of dim 'b' does not match", ): - z = x.dot(y) - fn = function([x, y], z) - x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) - y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) + x.dot(y) + + # Symbolic dimension size mismatches + x = xtensor("x", dims=("a", "b"), shape=(2, None)) + y = xtensor("y", dims=("b", "c"), shape=(None, 3)) + z = x.dot(y) + fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) + # Doesn't fail until the rewrite + with pytest.raises(ValueError, match="not aligned"): fn(x_test, y_test) From 1caf97738a05999a4015097fc3ce758e182f422e Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 17 Jun 2025 13:11:03 -0400 Subject: [PATCH 12/12] Moving shape check to make_node --- pytensor/xtensor/math.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 566bb3def6..55953ee452 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -164,9 +164,20 @@ def make_node(self, x, y): x_shape_dict = dict(zip(x.type.dims, x.type.shape)) y_shape_dict = dict(zip(y.type.dims, y.type.shape)) - shape_dict = {**x_shape_dict, **y_shape_dict} + + # Check for dimension size mismatches (concrete only) + for dim in self.dims: + x_shape = x_shape_dict.get(dim, None) + y_shape = y_shape_dict.get(dim, None) + if ( + isinstance(x_shape, int) + and isinstance(y_shape, int) + and x_shape != y_shape + ): + raise ValueError(f"Size of dim '{dim}' does not match") # Determine output dimensions + shape_dict = {**x_shape_dict, **y_shape_dict} out_dims = tuple(d for d in shape_dict if d not in self.dims) # Determine output shape @@ -231,17 +242,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): if d not in union: raise ValueError(f"Dimension {d} not found in either input") - # Check for dimension size mismatches (concrete only) - for dim in intersection: - x_idx = x.type.dims.index(dim) - y_idx = y.type.dims.index(dim) - if ( - isinstance(x.type.shape[x_idx], int) - and isinstance(y.type.shape[y_idx], int) - and x.type.shape[x_idx] != y.type.shape[y_idx] - ): - raise ValueError(f"Size of dim '{dim}' does not match") - result = XDot(dims=tuple(dim_set))(x, y) return result
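
Usage sketch for the final API in this series. This is illustrative only and not part of the patches: xr_function is the DataArray-aware helper used throughout tests/xtensor/test_math.py, and its import path below is an assumption about the test layout.

    # Minimal usage sketch of xtensor dot as of PATCH 12/12.
    # Assumption: xr_function is the DataArray-aware helper from the test suite.
    import numpy as np
    from xarray import DataArray

    from pytensor.xtensor.type import xtensor
    from tests.xtensor.util import xr_function  # assumed helper location

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(3, 4))

    z = x.dot(y)  # contracts the shared dim "b": dims ("a", "c"), shape (2, 4)
    total = x.dot(y, dim=...)  # contracts every dim of both inputs: dims ()

    fn = xr_function([x, y], z)
    x_test = DataArray(np.arange(6.0).reshape(2, 3), dims=("a", "b"))
    y_test = DataArray(np.arange(12.0).reshape(3, 4), dims=("b", "c"))
    print(fn(x_test, y_test))  # matches x_test.dot(y_test)

    # Since PATCH 12/12, concrete size mismatches raise in XDot.make_node,
    # before any function is compiled.
    bad = xtensor("bad", dims=("b", "c"), shape=(4, 5))
    try:
        x.dot(bad)
    except ValueError as err:
        print(err)  # Size of dim 'b' does not match

For reference, the lowering in pytensor/xtensor/rewriting/math.py (PATCH 08 onward) maps dimension names to einsum subscripts; contracted dims are simply omitted from the output subscripts. A standalone sketch of that mapping, using made-up dims:

    # Sketch of the subscript construction used by lower_dot.
    from string import ascii_lowercase

    x_dims, y_dims, out_dims = ("a", "b", "c"), ("b", "c", "d"), ("a", "d")
    all_dims = list(dict.fromkeys(x_dims + y_dims + out_dims))  # preserve order
    dim_to_char = dict(zip(all_dims, ascii_lowercase))
    einsum_str = (
        "".join(dim_to_char[d] for d in x_dims)
        + ","
        + "".join(dim_to_char[d] for d in y_dims)
        + "->"
        + "".join(dim_to_char[d] for d in out_dims)
    )
    print(einsum_str)  # abc,bcd->ad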