Implement xarray-like labeled tensors and semantics #1411

Merged: 18 commits, Jun 21, 2025
15 changes: 13 additions & 2 deletions .github/workflows/test.yml
@@ -82,11 +82,12 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
- "tests/scan"
- "tests/sparse"
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
- "tests/tensor/conv"
- "tests/tensor/rewriting"
- "tests/tensor/test_math.py"
@@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
5 changes: 5 additions & 0 deletions pytensor/compile/mode.py
@@ -67,6 +67,8 @@ def register_linker(name, linker):
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
@@ -77,6 +79,7 @@ def register_linker(name, linker):
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM"
OPT_MERGE.name = "OPT_MERGE"
OPT_FAST_RUN.name = "OPT_FAST_RUN"
OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE"
@@ -95,6 +98,7 @@ def register_linker(name, linker):
None: OPT_NONE,
"None": OPT_NONE,
"merge": OPT_MERGE,
"minimum_compile": OPT_MINIMUM,
"o4": OPT_FAST_RUN,
"o3": OPT_O3,
"o2": OPT_O2,
@@ -191,6 +195,7 @@ def apply(self, fgraph):
"merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
)


# After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes.
# The opt should not do anything that need shape inference.
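For reference, the new `minimum_compile` entry is looked up by name like the other predefined rewrite queries, so a mode can request just this minimal rewrite set; a small sketch (the linker choice here is illustrative):

```python
from pytensor.compile.mode import Mode

# "minimum_compile" resolves to OPT_MINIMUM, i.e. only the rewrites needed to
# make graphs with "dummy" Operations (such as unlowered XOps) evaluable
minimum_mode = Mode(linker="py", optimizer="minimum_compile")
```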
23 changes: 11 additions & 12 deletions pytensor/compile/ops.py
@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()):
ViewOp.c_code_and_version[type] = (code, version)


class ViewOp(COp):
"""
Returns an inplace view of the input. Used internally by PyTensor.

"""
class TypeCastingOp(COp):
"""Op that performs a graph-level type cast operation, but has no effect computation-wise (identity function)."""

view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
@@ -47,13 +44,8 @@ class ViewOp(COp):
__props__: tuple = ()
_f16_ok: bool = True

def make_node(self, x):
return Apply(self, [x], [x.type()])

def perform(self, node, inp, out):
(x,) = inp
(z,) = out
z[0] = x
def perform(self, node, inputs, outputs_storage):
outputs_storage[0][0] = inputs[0]

def __str__(self):
return f"{self.__class__.__name__}"
@@ -90,6 +82,13 @@ def c_code_cache_version(self):

return tuple(version)


class ViewOp(TypeCastingOp):
"""Returns an inplace view of the input. Used internally by PyTensor."""

def make_node(self, x):
return Apply(self, [x], [x.type()])

def infer_shape(self, fgraph, node, input_shapes):
return input_shapes

10 changes: 5 additions & 5 deletions pytensor/link/jax/dispatch/basic.py
@@ -8,7 +8,7 @@

from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
@@ -111,12 +111,12 @@ def deepcopyop(x):
return deepcopyop


@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op, **kwargs):
def viewop(x):
@jax_funcify.register(TypeCastingOp)
def jax_funcify_TypeCastingOp(op, **kwargs):
def type_cast(x):
return x

return viewop
return type_cast


@jax_funcify.register(OpFromGraph)
10 changes: 5 additions & 5 deletions pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,7 @@

import numpy as np

from pytensor.compile.ops import ViewOp
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
@@ -198,14 +198,14 @@ def cast(x):


@numba_basic.numba_njit
def viewop(x):
def identity(x):
return x


@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
return numba_basic.global_numba_func(viewop)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(identity)


@numba_basic.numba_njit
10 changes: 9 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -9,7 +9,7 @@
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
@@ -71,6 +71,14 @@
)


@pytorch_funcify.register(TypeCastingOp)
def pytorch_funcify_CastingOp(op, node, **kwargs):
def type_cast(x):
return x

return type_cast


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -4551,7 +4551,7 @@ def ix_(*args):
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
out.append(new)
return tuple(out)

18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


class CumsumOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj


class CumprodOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj


def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

4 changes: 2 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)


nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
nbinom = negative_binomial = NegBinomialRV()


class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):

multinomial = MultinomialRV()


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


19 changes: 18 additions & 1 deletion pytensor/tensor/random/op.py
@@ -1,3 +1,4 @@
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
@@ -32,7 +33,20 @@
from pytensor.tensor.variable import TensorVariable


class RandomVariable(Op):
class RNGConsumerOp(Op):
"""Baseclass for Ops that consume RNGs."""

@abc.abstractmethod
def update(self, node: Apply) -> dict[Variable, Variable]:
"""Symbolic update expression for input RNG variables.

Returns a dictionary with the symbolic expressions required for correct updating
of RNG variables in repeated function evaluations.
"""
pass


class RandomVariable(RNGConsumerOp):
"""An `Op` that produces a sample from a random variable.

This is essentially `RandomFunction`, except that it removes the
@@ -123,6 +137,9 @@
if self.inplace:
self.destroy_map = {0: [0]}

def update(self, node: Apply) -> dict[Variable, Variable]:
return {node.inputs[0]: node.outputs[0]}

def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.

4 changes: 1 addition & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var]


@register_infer_shape
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.

"""
if not isinstance(node.op, Assert):
return

return [node.inputs[0]]


29 changes: 29 additions & 0 deletions pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code


@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
"""
return product(*(range(s) for s in shape))


def get_static_shape_from_size_variables(
size_vars: Sequence[Variable],
) -> tuple[int | None, ...]:
"""Get static shape from size variables.

Parameters
----------
size_vars : Sequence[Variable]
A sequence of variables representing the size of each dimension.
Returns
-------
tuple[int | None, ...]
A tuple containing the static lengths of each dimension, or None if
the length is not statically known.
"""
from pytensor.tensor.basic import get_scalar_constant_value

static_lengths: list[None | int] = [None] * len(size_vars)
for i, length in enumerate(size_vars):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_lengths[i] = int(static_length)
return tuple(static_lengths)
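A small usage sketch of this helper:

```python
import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

size_vars = [pt.constant(3), pt.iscalar("n")]
# The constant length is recovered statically; the symbolic one stays unknown
assert get_static_shape_from_size_variables(size_vars) == (3, None)
```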
3 changes: 3 additions & 0 deletions pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self)

def flatten(self, ndim=1):
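With the no-op shortcut above, a dimshuffle that matches the existing order no longer adds a node; a quick sketch:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
# Identity pattern: returns the very same variable
assert x.dimshuffle(0, 1) is x
# A real permutation still builds a DimShuffle node
assert x.dimshuffle(1, 0) is not x
```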
14 changes: 14 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,14 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
100 changes: 100 additions & 0 deletions pytensor/xtensor/basic.py
@@ -0,0 +1,100 @@
from collections.abc import Sequence

from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
"""A base class for XOps that shouldn't be materialized"""

def perform(self, node, inputs, outputs):
raise NotImplementedError(
f"xtensor operation {self} must be lowered to equivalent tensor operations"
)


class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.
This is like a `ViewOp`, but without the expectation that the input and output have identical types.
"""


class TensorFromXTensor(XTypeCastOp):
__props__ = ()

def make_node(self, x):
if not isinstance(x.type, XTensorType):
raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
output = TensorType(x.type.dtype, shape=x.type.shape)()
return Apply(self, [x], [output])

def L_op(self, inputs, outs, g_outs):
[x] = inputs
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XTypeCastOp):
__props__ = ("dims",)

def __init__(self, dims: Sequence[str]):
super().__init__()
self.dims = tuple(dims)

def make_node(self, x):
if not isinstance(x.type, TensorType):
raise TypeError(f"x must have a TensorType, got {type(x.type)}")
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
return Apply(self, [x], [output])

def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]


def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)


class Rename(XTypeCastOp):
__props__ = ("new_dims",)

def __init__(self, new_dims: tuple[str, ...]):
super().__init__()
self.new_dims = new_dims

def make_node(self, x):
x = as_xtensor(x)
output = x.type.clone(dims=self.new_dims)()
return Apply(self, [x], [output])

def L_op(self, inputs, outs, g_outs):
[x] = inputs
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
if names:
raise ValueError("Cannot use both positional and keyword names in rename")
names = name_dict

x = as_xtensor(x)
old_names = x.type.dims
new_names = list(old_names)
for old_name, new_name in names.items():
try:
new_names[old_names.index(old_name)] = new_name
except ValueError:
raise ValueError(
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
)

return Rename(tuple(new_names))(x)
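A small sketch of the casting and renaming helpers defined in this file (the dimension names are illustrative):

```python
import pytensor.tensor as pt
from pytensor.xtensor.basic import rename, tensor_from_xtensor, xtensor_from_tensor

x = xtensor_from_tensor(pt.zeros((2, 3)), dims=("a", "b"))
y = rename(x, a="rows")        # keyword form
z = rename(x, {"b": "cols"})   # dict form
assert y.type.dims == ("rows", "b")
assert z.type.dims == ("a", "cols")
# Round-tripping back to a plain tensor drops the labels
assert tensor_from_xtensor(x).type.shape == (2, 3)
```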
219 changes: 219 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -0,0 +1,219 @@
# HERE LIE DRAGONS
[Inline review thread]
Member: new files need a license.
Member Author (@ricardoV94, Jun 4, 2025): @lucianopaz can we bring your pre-commit hook over?
Member: Sure
Member Author: Can be a separate PR, wouldn't be surprised if we have files missing it in main.

# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from typing import Literal

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx, indexed_dim: str):
if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
raise TypeError(
"XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
)
if isinstance(idx, slice):
idx = make_slice(idx)
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
pass
elif (
isinstance(idx, tuple)
and len(idx) == 2
and (
isinstance(idx[0], str)
or (
isinstance(idx[0], tuple | list)
and all(isinstance(d, str) for d in idx[0])
)
)
):
# Special case for ("x", array) that xarray supports
dim, idx = idx
if isinstance(idx, Variable) and isinstance(idx.type, XTensorType):
raise IndexError(
f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. "
"Use .rename() instead."
)
if isinstance(dim, str):
dims = (dim,)
else:
dims = tuple(dim)
idx = as_xtensor(as_tensor(idx), dims=dims)
else:
# Must be integer / boolean indices, we already counted for None and slices
try:
idx = as_xtensor(idx)
except TypeError:
idx = as_tensor(idx)
if idx.type.ndim > 1:
# Same error that xarray raises
raise IndexError(
"Unlabeled multi-dimensional array cannot be used for indexing"
)
# This is implicitly an XTensorVariable with dim matching the indexed one
idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim])

if idx.type.dtype == "bool":
if idx.type.ndim != 1:
# xarray allows `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379
# Otherwise, it is always restricted to 1d boolean indexing arrays
raise NotImplementedError(
"Only 1d boolean indexing arrays are supported"
)
if idx.type.dims != (indexed_dim,):
raise IndexError(
"Boolean indexer should be unlabeled or on the same dimension to the indexed array. "
f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}."
)

# Convert to nonzero indices
idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims)

elif idx.type.dtype not in discrete_dtypes:
raise TypeError("Numerical indices must be integers or boolean")

return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
if dim_length is None:
return None

if isinstance(slc, Constant):
d = slc.data
start, stop, step = d.start, d.stop, d.step
elif slc.owner is None:
# It's a root variable no way of knowing what we're getting
return None
else:
# It's a MakeSliceOp
start, stop, step = slc.owner.inputs
if isinstance(start, Constant):
start = start.data
else:
return None
if isinstance(stop, Constant):
stop = stop.data
else:
return None
if isinstance(step, Constant):
step = step.data
else:
return None

return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
__props__ = ()

def make_node(self, x, *idxs):
x = as_xtensor(x)

if any(idx is Ellipsis for idx in idxs):
if idxs.count(Ellipsis) > 1:
raise IndexError("an index can only have a single ellipsis ('...')")

# Convert intermediate Ellipsis to slice(None)
ellipsis_loc = idxs.index(Ellipsis)
n_implied_none_slices = x.type.ndim - (len(idxs) - 1)
idxs = (
*idxs[:ellipsis_loc],
*((slice(None),) * n_implied_none_slices),
*idxs[ellipsis_loc + 1 :],
)

x_ndim = x.type.ndim
x_dims = x.type.dims
x_shape = x.type.shape
out_dims = []
out_shape = []

def combine_dim_info(idx_dim, idx_dim_shape):
if idx_dim not in out_dims:
# First information about the dimension length
out_dims.append(idx_dim)
out_shape.append(idx_dim_shape)
else:
# Dim already introduced in output by a previous index
# Update static shape or raise if incompatible
out_dim_pos = out_dims.index(idx_dim)
out_dim_shape = out_shape[out_dim_pos]
if out_dim_shape is None:
# We don't know the size of the dimension yet
out_shape[out_dim_pos] = idx_dim_shape
elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
raise IndexError(
f"Dimension of indexers mismatch for dim {idx_dim}"
)

if len(idxs) > x_ndim:
raise IndexError("Too many indices")

idxs = [
as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False)
]

for i, idx in enumerate(idxs):
if isinstance(idx.type, SliceType):
idx_dim = x_dims[i]
idx_dim_shape = get_static_slice_length(idx, x_shape[i])
combine_dim_info(idx_dim, idx_dim_shape)
else:
if idx.type.ndim == 0:
# Scalar index, dimension is dropped
continue

assert isinstance(idx.type, XTensorType)

idx_dims = idx.type.dims
for idx_dim in idx_dims:
idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
combine_dim_info(idx_dim, idx_dim_shape)

for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
# Add back any unindexed dimensions
if dim_i not in out_dims:
# If the dimension was not indexed, we keep it as is
combine_dim_info(dim_i, shape_i)

output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])


index = Index()


class IndexUpdate(XOp):
__props__ = ("mode",)

def __init__(self, mode: Literal["set", "inc"]):
if mode not in ("set", "inc"):
raise ValueError("mode must be 'set' or 'inc'")

self.mode = mode

def make_node(self, x, y, *idxs):
# Call Index on (x, *idxs) to process inputs and infer output type
x_view_node = index.make_node(x, *idxs)
x, *idxs = x_view_node.inputs
[x_view] = x_view_node.outputs

try:
y = as_xtensor(y)
except TypeError:
y = as_xtensor(as_tensor(y), dims=x_view.type.dims)

if not set(y.type.dims).issubset(x_view.type.dims):
raise ValueError(
f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}"
)

out = x.type()
return Apply(self, [x, y, *idxs], [out])


index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc")
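A small sketch of what the `Index` op infers at the type level (the dimension names and shapes are illustrative; user-facing indexing would normally go through `XTensorVariable.__getitem__`, which is not part of this file):

```python
from pytensor.xtensor import xtensor
from pytensor.xtensor.indexing import index

x = xtensor(dtype="float64", dims=("a", "b"), shape=(5, 7))
i = xtensor(dtype="int64", dims=("a",), shape=(3,))

# A labeled integer index selects along its own dim ("a"); "b" is untouched
y = index(x, i)
assert y.type.dims == ("a", "b")
```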
70 changes: 70 additions & 0 deletions pytensor/xtensor/linalg.py
@@ -0,0 +1,70 @@
from collections.abc import Sequence
from typing import Literal

from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.xtensor.type import as_xtensor
from pytensor.xtensor.vectorization import XBlockwise


def cholesky(
x,
lower: bool = True,
*,
check_finite: bool = False,
overwrite_a: bool = False,
on_error: Literal["raise", "nan"] = "raise",
dims: Sequence[str],
):
if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}")

core_op = Cholesky(
lower=lower,
check_finite=check_finite,
overwrite_a=overwrite_a,
on_error=on_error,
)
core_dims = (
((dims[0], dims[1]),),
((dims[0], dims[1]),),
)
x_op = XBlockwise(core_op, core_dims=core_dims)
return x_op(x)


def solve(
a,
b,
dims: Sequence[str],
assume_a="gen",
lower: bool = False,
check_finite: bool = False,
):
a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
output_core_dims: tuple[tuple[str] | tuple[str, str]]
if len(dims) == 2:
b_ndim = 1
[m1_dim] = [dim for dim in dims if dim not in b.type.dims]
m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
# The shared dim disappears in the output
output_core_dims = ((m1_dim,),)
elif len(dims) == 3:
b_ndim = 2
[n_dim] = [dim for dim in dims if dim not in a.type.dims]
[m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
# The shared dim disappears in the output
output_core_dims = ((m1_dim, n_dim),)
else:
raise ValueError("Solve dims must have length 2 or 3")

core_op = Solve(
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
)
x_op = XBlockwise(
core_op,
core_dims=(input_core_dims, output_core_dims),
)
return x_op(a, b)
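A usage sketch for these labeled-dimension wrappers (dimension names are illustrative):

```python
from pytensor.xtensor import xtensor
from pytensor.xtensor.linalg import cholesky, solve

cov = xtensor(dtype="float64", dims=("row", "col"), shape=(3, 3))
b = xtensor(dtype="float64", dims=("col",), shape=(3,))

# Both core dims of the factorization are named explicitly
chol = cholesky(cov, dims=("row", "col"))

# With two dims, the one shared between `a` and `b` ("col") is contracted away
x = solve(cov, b, dims=("row", "col"))
```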
252 changes: 252 additions & 0 deletions pytensor/xtensor/math.py
@@ -0,0 +1,252 @@
import sys
from collections.abc import Iterable
from types import EllipsisType

import numpy as np

import pytensor.scalar as ps
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping, upcast
from pytensor.xtensor.basic import XOp, as_xtensor
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import XElemwise


this_module = sys.modules[__name__]


def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
out = XElemwise(core_op)
out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables"
return out


abs = _as_xelemwise(ps.abs)
add = _as_xelemwise(ps.add)
logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_)
angle = _as_xelemwise(ps.angle)
arccos = _as_xelemwise(ps.arccos)
arccosh = _as_xelemwise(ps.arccosh)
arcsin = _as_xelemwise(ps.arcsin)
arcsinh = _as_xelemwise(ps.arcsinh)
arctan = _as_xelemwise(ps.arctan)
arctan2 = _as_xelemwise(ps.arctan2)
arctanh = _as_xelemwise(ps.arctanh)
betainc = _as_xelemwise(ps.betainc)
betaincinv = _as_xelemwise(ps.betaincinv)
ceil = _as_xelemwise(ps.ceil)
clip = _as_xelemwise(ps.clip)
complex = _as_xelemwise(ps.complex)
conjugate = conj = _as_xelemwise(ps.conj)
cos = _as_xelemwise(ps.cos)
cosh = _as_xelemwise(ps.cosh)
deg2rad = _as_xelemwise(ps.deg2rad)
equal = eq = _as_xelemwise(ps.eq)
erf = _as_xelemwise(ps.erf)
erfc = _as_xelemwise(ps.erfc)
erfcinv = _as_xelemwise(ps.erfcinv)
erfcx = _as_xelemwise(ps.erfcx)
erfinv = _as_xelemwise(ps.erfinv)
exp = _as_xelemwise(ps.exp)
exp2 = _as_xelemwise(ps.exp2)
expm1 = _as_xelemwise(ps.expm1)
floor = _as_xelemwise(ps.floor)
floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div)
gamma = _as_xelemwise(ps.gamma)
gammainc = _as_xelemwise(ps.gammainc)
gammaincc = _as_xelemwise(ps.gammaincc)
gammainccinv = _as_xelemwise(ps.gammainccinv)
gammaincinv = _as_xelemwise(ps.gammaincinv)
gammal = _as_xelemwise(ps.gammal)
gammaln = _as_xelemwise(ps.gammaln)
gammau = _as_xelemwise(ps.gammau)
greater_equal = ge = _as_xelemwise(ps.ge)
greater = gt = _as_xelemwise(ps.gt)
hyp2f1 = _as_xelemwise(ps.hyp2f1)
i0 = _as_xelemwise(ps.i0)
i1 = _as_xelemwise(ps.i1)
identity = _as_xelemwise(ps.identity)
imag = _as_xelemwise(ps.imag)
logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert)
isinf = _as_xelemwise(ps.isinf)
isnan = _as_xelemwise(ps.isnan)
iv = _as_xelemwise(ps.iv)
ive = _as_xelemwise(ps.ive)
j0 = _as_xelemwise(ps.j0)
j1 = _as_xelemwise(ps.j1)
jv = _as_xelemwise(ps.jv)
kve = _as_xelemwise(ps.kve)
less_equal = le = _as_xelemwise(ps.le)
log = _as_xelemwise(ps.log)
log10 = _as_xelemwise(ps.log10)
log1mexp = _as_xelemwise(ps.log1mexp)
log1p = _as_xelemwise(ps.log1p)
log2 = _as_xelemwise(ps.log2)
less = lt = _as_xelemwise(ps.lt)
mod = _as_xelemwise(ps.mod)
multiply = mul = _as_xelemwise(ps.mul)
negative = neg = _as_xelemwise(ps.neg)
not_equal = neq = _as_xelemwise(ps.neq)
logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_)
owens_t = _as_xelemwise(ps.owens_t)
polygamma = _as_xelemwise(ps.polygamma)
power = pow = _as_xelemwise(ps.pow)
psi = _as_xelemwise(ps.psi)
rad2deg = _as_xelemwise(ps.rad2deg)
real = _as_xelemwise(ps.real)
reciprocal = _as_xelemwise(ps.reciprocal)
round = _as_xelemwise(ps.round_half_to_even)
maximum = _as_xelemwise(ps.scalar_maximum)
minimum = _as_xelemwise(ps.scalar_minimum)
second = _as_xelemwise(ps.second)
sigmoid = _as_xelemwise(ps.sigmoid)
sign = _as_xelemwise(ps.sign)
sin = _as_xelemwise(ps.sin)
sinh = _as_xelemwise(ps.sinh)
softplus = _as_xelemwise(ps.softplus)
square = sqr = _as_xelemwise(ps.sqr)
sqrt = _as_xelemwise(ps.sqrt)
subtract = sub = _as_xelemwise(ps.sub)
where = switch = _as_xelemwise(ps.switch)
tan = _as_xelemwise(ps.tan)
tanh = _as_xelemwise(ps.tanh)
tri_gamma = _as_xelemwise(ps.tri_gamma)
true_divide = true_div = _as_xelemwise(ps.true_div)
trunc = _as_xelemwise(ps.trunc)
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor)

_xelemwise_cast_op: dict[str, XElemwise] = {}


def cast(x, dtype):
if dtype == "floatX":
dtype = config.floatX
else:
dtype = np.dtype(dtype).name

x = as_xtensor(x)
if x.type.dtype == dtype:
return x

if x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
raise TypeError(
"Casting from complex to real is ambiguous: consider"
" real(), imag(), angle() or abs()"
)

if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)


def softmax(x, dim=None):
exp_x = exp(x)
return exp_x / exp_x.sum(dim=dim)


class XDot(XOp):
"""Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.
Parameters
----------
dims : tuple of str
The dimensions to contract over. If None, will contract over all matching dimensions.
"""

__props__ = ("dims",)

def __init__(self, dims: Iterable[str]):
self.dims = dims
super().__init__()

def make_node(self, x, y):
x = as_xtensor(x)
y = as_xtensor(y)

x_shape_dict = dict(zip(x.type.dims, x.type.shape))
y_shape_dict = dict(zip(y.type.dims, y.type.shape))

# Check for dimension size mismatches (concrete only)
for dim in self.dims:
x_shape = x_shape_dict.get(dim, None)
y_shape = y_shape_dict.get(dim, None)
if (
isinstance(x_shape, int)
and isinstance(y_shape, int)
and x_shape != y_shape
):
raise ValueError(f"Size of dim '{dim}' does not match")

# Determine output dimensions
shape_dict = {**x_shape_dict, **y_shape_dict}
out_dims = tuple(d for d in shape_dict if d not in self.dims)

# Determine output shape
out_shape = tuple(shape_dict[d] for d in out_dims)

# Determine output dtype
out_dtype = upcast(x.type.dtype, y.type.dtype)

out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, y], [out])


def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
"""Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.
Parameters
----------
x : XTensorVariable
First input tensor
y : XTensorVariable
Second input tensor
dim : str, Iterable[Hashable], EllipsisType, or None, optional
The dimensions to contract over. If None, will contract over all matching dimensions.
If Ellipsis (...), will contract over all dimensions.
Returns
-------
XTensorVariable
The result of the matrix multiplication.
Examples
--------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
>>> z = dot(x, y) # Result has dimensions ("a", "c")
>>> z = dot(x, y, dim=...) # Contract over all dimensions
"""
x = as_xtensor(x)
y = as_xtensor(y)

x_dims = set(x.type.dims)
y_dims = set(y.type.dims)
intersection = x_dims & y_dims
union = x_dims | y_dims

# Canonicalize dims
if dim is None:
dim_set = intersection
elif dim is ...:
dim_set = union
elif isinstance(dim, str):
dim_set = {dim}
elif isinstance(dim, Iterable):
dim_set = set(dim)

# Validate provided dims
# Check if any dimension is not found in either input
for d in dim_set:
if d not in union:
raise ValueError(f"Dimension {d} not found in either input")

result = XDot(dims=tuple(dim_set))(x, y)

return result
168 changes: 168 additions & 0 deletions pytensor/xtensor/random.py
@@ -0,0 +1,168 @@
from collections.abc import Sequence
from functools import wraps
from typing import Literal

import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.math import sqrt
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
core_op: RandomVariable,
core_inps_dims_map: Sequence[Sequence[int]] | None = None,
core_out_dims_map: Sequence[int] | None = None,
):
"""Helper function to define an XRV constructor.
Parameters
----------
core_op : RandomVariable
The core random variable operation to wrap.
core_inps_dims_map : Sequence[Sequence[int]] | None, optional
A sequence of sequences mapping the core dimensions (specified by the user)
for each input parameter. This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for each input.
If None, it assumes the core dimensions are positional from left to right.
core_out_dims_map : Sequence[int] | None, optional
A sequence mapping the core dimensions (specified by the user) for the output variable.
This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for the output.
If None, it assumes the core dimensions are positional from left to right.
"""
if core_inps_dims_map is None:
# Assume core_dims map positionally from left to right
core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
if core_out_dims_map is None:
# Assume core_dims map positionally from left to right
core_out_dims_map = tuple(range(core_op.ndim_supp))

core_dims_needed = max(
(*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
)

@wraps(core_op)
def xrv_constructor(
*params,
core_dims: Sequence[str] | str | None = None,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
if core_dims is None:
core_dims = ()
if core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
)
elif isinstance(core_dims, str):
core_dims = (core_dims,)

if len(core_dims) != core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
)

full_input_core_dims = tuple(
tuple(core_dims[i] for i in inp_dims_map)
for inp_dims_map in core_inps_dims_map
)
full_output_core_dims = tuple(core_dims[i] for i in core_out_dims_map)
full_core_dims = (full_input_core_dims, full_output_core_dims)

if extra_dims is None:
extra_dims = {}

return XRV(
core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
)(rng, *extra_dims.values(), *params)

return xrv_constructor


bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)


def multivariate_normal(
mean,
cov,
*,
core_dims: Sequence[str],
extra_dims=None,
rng=None,
method: Literal["cholesky", "svd", "eigh"] = "cholesky",
):
mean = as_xtensor(mean)
if len(core_dims) != 2:
raise ValueError(
f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
)

# Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
# This will be the core dimension of the output
if core_dims[0] not in mean.type.dims:
core_dims = core_dims[::-1]

xop = _as_xrv(ptr.MvNormalRV(method=method))
return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)


def standard_normal(
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Standard normal random variable."""
return normal(0, 1, extra_dims=extra_dims, rng=rng)


def chisquare(
df,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Chi-square random variable."""
return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng)


def rayleigh(
scale,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
"""Rayleigh random variable."""

df = scale * 0 + 2 # Poor man's broadcasting, to pass dimensions of scale to the RV
return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale
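A usage sketch for the constructors in this module (dimension names, sizes, and the explicit RNG variable are illustrative assumptions; `XRV`'s exact handling of defaults is defined elsewhere):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
import pytensor.xtensor.random as pxr
from pytensor.xtensor import xtensor

rng = pytensor.shared(np.random.default_rng(0))
mu = xtensor(dtype="float64", dims=("city",), shape=(3,))

# Parameters broadcast by dimension name; extra_dims adds new batch dims
x = pxr.normal(mu, 1.0, extra_dims={"draw": pt.as_tensor(100)}, rng=rng)

# Multivariate RVs name their core dims explicitly
cov = xtensor(dtype="float64", dims=("city", "city2"), shape=(3, 3))
mv = pxr.multivariate_normal(mu, cov, core_dims=("city", "city2"), rng=rng)
```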

69 changes: 69 additions & 0 deletions pytensor/xtensor/readme.md
@@ -0,0 +1,69 @@
# XTensor Module

This module implements an abstraction layer on top of regular tensor operations that behaves like xarray.

A new type, `XTensorType`, generalizes `TensorType` by adding a `dims` attribute that labels the dimensions of the tensor.

Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.

The module implements several PyTensor operations (`XOp`s) whose signatures mimic those of xarray (and xarray-einstats) DataArray operations.
These operations, unlike most regular PyTensor operations, cannot be evaluated directly; they require a rewrite (lowering) into
a regular tensor graph that can itself be evaluated as usual.

As in regular PyTensor, we don't need an Op for every method or function in xarray's public API.
If existing XOps can be composed to produce the desired result, we use them directly (see the `softmax` example at the end of this readme).

## Coordinates
For now, there's no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.

## Example

```python
import pytensor.tensor as pt
import pytensor.xtensor as px

a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))

ax = px.as_xtensor(a, dims=["x"])
bx = px.as_xtensor(b, dims=["y"])

zx = ax + bx
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))

z = zx.values
z.dprint()
# TensorFromXTensor [id A]
# └─ XElemwise{scalar_op=Add()} [id B]
# ├─ XTensorFromTensor{dims=('x',)} [id C]
# │ └─ a [id D]
# └─ XTensorFromTensor{dims=('y',)} [id E]
# └─ b [id F]
```

Once we compile the graph, no `XOp`s are left.

```python
import pytensor

with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)

# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)

fn.dprint()
# Add [id A] 2
# ├─ ExpandDims{axis=1} [id B] 1
# │ └─ a [id C]
# └─ ExpandDims{axis=0} [id D] 0
# └─ b [id E]
```
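
As an example of the composition mentioned above, `softmax` in `pytensor/xtensor/math.py` is written purely in terms of existing XOps:

```python
def softmax(x, dim=None):
    exp_x = exp(x)
    return exp_x / exp_x.sum(dim=dim)
```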



125 changes: 125 additions & 0 deletions pytensor/xtensor/reduction.py
@@ -0,0 +1,125 @@
import typing
from collections.abc import Sequence
from functools import partial
from types import EllipsisType

import pytensor.scalar as ps
from pytensor.graph.basic import Apply
from pytensor.tensor.math import variadic_mul
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import neq, sqrt
from pytensor.xtensor.math import sqr as square
from pytensor.xtensor.type import as_xtensor, xtensor


REDUCE_DIM = str | Sequence[str] | EllipsisType | None


class XReduce(XOp):
__slots__ = ("binary_op", "dims")

def __init__(self, binary_op, dims: Sequence[str]):
super().__init__()
self.binary_op = binary_op
# Order of reduce dims doesn't change the behavior of the Op
self.dims = tuple(sorted(dims))

def make_node(self, x):
x = as_xtensor(x)
x_dims = x.type.dims
x_dims_set = set(x_dims)
reduce_dims_set = set(self.dims)
if x_dims_set == reduce_dims_set:
out_dims, out_shape = [], []
else:
if not reduce_dims_set.issubset(x_dims_set):
raise ValueError(
f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
)
out_dims, out_shape = zip(
*[
(d, s)
for d, s in zip(x_dims, x.type.shape)
if d not in reduce_dims_set
]
)
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])


def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
return (dim,)
elif dim is None or dim is Ellipsis:
x = as_xtensor(x)
return typing.cast(tuple[str], x.type.dims)
return dim


def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
dims = _process_user_dims(x, dim)
return XReduce(binary_op=binary_op, dims=dims)(x)


sum = partial(reduce, binary_op=ps.add)
prod = partial(reduce, binary_op=ps.mul)
max = partial(reduce, binary_op=ps.scalar_maximum)
min = partial(reduce, binary_op=ps.scalar_minimum)


def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
x = as_xtensor(x)
if x.type.dtype != "bool":
x = neq(x, 0)
return reduce(x, dim=dim, binary_op=binary_op)


all = partial(bool_reduce, binary_op=ps.and_)
any = partial(bool_reduce, binary_op=ps.or_)


def _infer_reduced_size(original_var, reduced_var):
reduced_dims = reduced_var.dims
return variadic_mul(
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
)


def mean(x, dim: REDUCE_DIM):
x = as_xtensor(x)
sum_x = sum(x, dim)
n = _infer_reduced_size(x, sum_x)
return sum_x / n


def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
x = as_xtensor(x)
x_mean = mean(x, dim)
n = _infer_reduced_size(x, x_mean)
# Sum the squared deviations over the reduced dims before normalizing
return sum(square(x - x_mean), dim) / (n - ddof)


def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
return sqrt(var(x, dim, ddof=ddof))


class XCumReduce(XOp):
__props__ = ("binary_op", "dims")

def __init__(self, binary_op, dims: Sequence[str]):
self.binary_op = binary_op
self.dims = tuple(sorted(dims)) # Order doesn't matter

def make_node(self, x):
x = as_xtensor(x)
out = x.type()
return Apply(self, [x], [out])


def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
dims = _process_user_dims(x, dim)
return XCumReduce(dims=dims, binary_op=binary_op)(x)


cumsum = partial(cumreduce, binary_op=ps.add)
cumprod = partial(cumreduce, binary_op=ps.mul)
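A usage sketch for these reduction helpers (dimension names are illustrative):

```python
from pytensor.xtensor import reduction as xr
from pytensor.xtensor import xtensor

x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))

total_a = xr.sum(x, dim="a")     # reduce a single labeled dim
assert total_a.type.dims == ("b",)

total_all = xr.sum(x)            # dim=None reduces over every dim
running = xr.cumsum(x, dim="b")  # cumulative reductions keep the full type
assert running.type.dims == ("a", "b")
```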
6 changes: 6 additions & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -0,0 +1,6 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.math
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
62 changes: 62 additions & 0 deletions pytensor/xtensor/rewriting/basic.py
@@ -0,0 +1,62 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
from pytensor.xtensor.basic import (
Rename,
TensorFromXTensor,
XTensorFromTensor,
xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
"""TensorFromXTensor(XTensorFromTensor(x)) -> x"""
[x] = node.inputs
if x.owner and isinstance(x.owner.op, XTensorFromTensor):
return [x.owner.inputs[0]]


@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
"""XTensorFromTensor(TensorFromXTensor(x)) -> x"""
[x] = node.inputs
if x.owner and isinstance(x.owner.op, TensorFromXTensor):
return [x.owner.inputs[0]]


@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
"""TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
[renamed_x] = node.inputs
if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
[x] = renamed_x.owner.inputs
return node.op(x, return_list=True)


@register_lower_xtensor
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
"""
Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)
"""
[renamed_x] = node.inputs

if renamed_x.owner:
if isinstance(renamed_x.owner.op, Rename):
[x] = renamed_x.owner.inputs
return [node.op(x)]
elif isinstance(renamed_x.owner.op, TensorFromXTensor):
[x] = renamed_x.owner.inputs
return [xtensor_from_tensor(x, dims=node.op.new_dims)]

212 changes: 212 additions & 0 deletions pytensor/xtensor/rewriting/indexing.py
@@ -0,0 +1,212 @@
from itertools import zip_longest

from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
if isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
return idx.data
elif idx.owner:
# MakeSlice Op
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible
return slice(
*[
None if isinstance(i.type, NoneTypeT) else i
for i in idx.owner.inputs
]
)
else:
return idx

if (
isinstance(idx.type, XTensorType)
and idx.type.ndim == 0
and idx.type.dtype != bool
):
return idx.values
raise TypeError("Cannot convert idx to basic idx")


def _lower_index(node):
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
xarray-like indexing has two modes:
1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices.
2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing.
An Index Op can combine both modes.
To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing.
We expand the dims of each index so they are as large as the number of output dimensions, place the indices that
belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes.
For instance, to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]].
This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as having
indices that have more than one dimension at the start.
In addition, xarray basic indices (slices) can be vectorized with other advanced indices (if they act on the same output dimension).
However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices
we have to convert the slices to equivalent advanced indices.
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
and then indexing it with the original slice. This index is then handled as a regular advanced index.
Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy.
When all advanced indices are consecutive, the respective view is located in the "original" location.
However, if advanced indices are separated by basic indices (slices in our case), the output views
always show up at the front of the array. This information is returned as the second output of this function,
which labels the final position of the indexed dimensions under this rule.
"""

assert isinstance(node.op, Index)

x, *idxs = node.inputs
[out] = node.outputs
x_tensor_indexed_dims = out.type.dims
x_tensor = tensor_from_xtensor(x)

if all(
(
isinstance(idx.type, SliceType)
or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
)
for idx in idxs
):
# Special case having just basic indexing
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]

else:
# General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
# May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index
x_dims = x.type.dims
x_shape = tuple(x.shape)
out_ndim = out.type.ndim
out_dims = out.type.dims
aligned_idxs = []
basic_idx_axis = []
# zip_longest adds the implicit slice(None)
for i, (idx, x_dim) in enumerate(
zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
):
if isinstance(idx.type, SliceType):
if not any(
(
isinstance(other_idx.type, XTensorType)
and x_dim in other_idx.dims
)
for j, other_idx in enumerate(idxs)
if j != i
):
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(idx)
basic_idx_axis.append(out_dims.index(x_dim))
else:
# Otherwise we need to convert the basic index into an equivalent advanced index
# And align it so it interacts correctly with the other advanced indices
adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
ds_order = ["x"] * out_ndim
ds_order[out_dims.index(x_dim)] = 0
aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
else:
assert isinstance(idx.type, XTensorType)
if idx.type.ndim == 0:
# Scalar index, we can use it directly
aligned_idxs.append(idx.values)
else:
# Vector index, we need to align the indexing dimensions with the base_dims
ds_order = ["x"] * out_ndim
for j, idx_dim in enumerate(idx.dims):
ds_order[out_dims.index(idx_dim)] = j
aligned_idxs.append(idx.values.dimshuffle(ds_order))

# Squeeze indexing dimensions that were not used because we kept basic indexing slices
if basic_idx_axis:
aligned_idxs = [
idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
else idx
for idx in aligned_idxs
]

x_tensor_indexed = x_tensor[tuple(aligned_idxs)]

if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs):
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
# We need to transpose them back to the expected output order
x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis]
x_tensor_indexed_dims = [
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
] + x_tensor_indexed_basic_dims

return x_tensor_indexed, x_tensor_indexed_dims


@register_lower_xtensor
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
"""Lower XTensorVariable indexing to regular TensorVariable indexing.
The bulk of the work is done by `_lower_index`, except for special logic to control the
location of non-consecutive advanced indices, and to preserve static shape information.
"""

[out] = node.outputs
out_dims = out.type.dims

x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node)
if x_tensor_indexed_dims != out_dims:
# Numpy moves advanced indexing dimensions to the front when they are not consecutive
# We need to transpose them back to the expected output order
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

# Add lost shape information
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)

new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter(tracks=[IndexUpdate])
def lower_index_update(fgraph, node):
"""Lower XTensorVariable index update to regular TensorVariable indexing update.
This rewrite requires converting the index view to a tensor-based equivalent expression,
just like `lower_index`. It then requires aligning the dimensions of y with the
dimensions of the index view, with special care for non-consecutive dimensions being
pulled to the front axis according to numpy rules.
"""
x, y, *idxs = node.inputs

# Lower the indexing part first
indexed_node = index.make_node(x, *idxs)
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node)
y_tensor = tensor_from_xtensor(y)

# Align dimensions of y with those of the indexed tensor x
y_dims = y.type.dims
y_dims_set = set(y_dims)
y_order = tuple(
y_dims.index(x_dim) if x_dim in y_dims_set else "x"
for x_dim in x_tensor_indexed_dims
)
# Remove useless left expand_dims
while len(y_order) > 0 and y_order[0] == "x":
y_order = y_order[1:]
if y_order != tuple(range(y_tensor.type.ndim)):
y_tensor = y_tensor.dimshuffle(y_order)

x_tensor_updated = inc_subtensor(
x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set"
)
new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims)
return [new_out]
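
To see the numpy semantics that `_lower_index` lowers to, here is a small standalone numpy sketch (illustrative only, not part of the diff): expanding indices into different axes yields orthogonal (outer) indexing, placing them in the same axis yields vectorized (point-wise) indexing, and a slice can be emulated by indexing an `arange` of the dimension length.

import numpy as np

x = np.arange(12).reshape(3, 4)

# Orthogonal (outer) indexing: indices broadcast against each other -> all combinations
rows = np.array([0, 2])[:, None]
cols = np.array([1, 3])[None, :]
assert x[rows, cols].shape == (2, 2)

# Vectorized (point-wise) indexing: indices share the same axis -> paired selection
assert x[np.array([0, 2]), np.array([1, 3])].shape == (2,)

# A basic slice behaves like an advanced index when applied to an arange of the dim length
slice_as_adv = np.arange(x.shape[1])[1:3]
assert np.array_equal(x[:, 1:3], x[:, slice_as_adv])
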
47 changes: 47 additions & 0 deletions pytensor/xtensor/rewriting/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from string import ascii_lowercase

from pytensor.graph import node_rewriter
from pytensor.tensor import einsum
from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XDot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.
This rewrite converts an XDot operation to a tensor-based dot operation,
handling dimension alignment and contraction.
"""
[x, y] = node.inputs
[out] = node.outputs

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
y_tensor = tensor_from_xtensor(y)

# Collect all dimension names across inputs and output
all_dims = list(
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims)
) # preserve order
if len(all_dims) > len(ascii_lowercase):
raise ValueError("Too many dimensions to map to einsum subscripts")


dim_to_char = dict(zip(all_dims, ascii_lowercase))

# Build einsum string
x_subs = "".join(dim_to_char[d] for d in x.type.dims)
y_subs = "".join(dim_to_char[d] for d in y.type.dims)
out_subs = "".join(dim_to_char[d] for d in out.type.dims)
einsum_str = f"{x_subs},{y_subs}->{out_subs}"

# Perform the einsum operation
out_tensor = einsum(einsum_str, x_tensor, y_tensor)

# Reshape to match the output shape
out_tensor = specify_shape(out_tensor, out.type.shape)

return [xtensor_from_tensor(out_tensor, out.type.dims)]
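
As a sanity check of the subscript construction in `lower_dot`, the following standalone sketch (not part of the diff; the dim names are hypothetical) shows how labeled dims map to an einsum string and that the result matches an ordinary matrix product.

import numpy as np
from string import ascii_lowercase

x_dims, y_dims, out_dims = ("a", "b"), ("b", "c"), ("a", "c")
all_dims = list(dict.fromkeys(x_dims + y_dims + out_dims))  # preserve order
dim_to_char = dict(zip(all_dims, ascii_lowercase))
x_subs = "".join(dim_to_char[d] for d in x_dims)
y_subs = "".join(dim_to_char[d] for d in y_dims)
out_subs = "".join(dim_to_char[d] for d in out_dims)
einsum_str = f"{x_subs},{y_subs}->{out_subs}"
assert einsum_str == "ab,bc->ac"

x_val = np.random.default_rng(0).random((2, 3))
y_val = np.random.default_rng(1).random((3, 4))
assert np.allclose(np.einsum(einsum_str, x_val, y_val), x_val @ y_val)
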
72 changes: 72 additions & 0 deletions pytensor/xtensor/rewriting/reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from functools import partial

import pytensor.scalar as ps
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.extra_ops import CumOp
from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.reduction import XCumReduce, XReduce
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XReduce])
def lower_reduce(fgraph, node):
[x] = node.inputs
[out] = node.outputs
x_dims = x.type.dims
reduce_dims = node.op.dims
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]

if not reduce_axis:
return [x]


match node.op.binary_op:
case ps.add:
tensor_op_class = Sum
case ps.mul:
tensor_op_class = Prod

case ps.and_:
tensor_op_class = All
case ps.or_:
tensor_op_class = Any
case ps.scalar_maximum:
tensor_op_class = Max
case ps.scalar_minimum:
tensor_op_class = Min
case _:

# Case without known/predefined Ops
tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op)


x_tensor = tensor_from_xtensor(x)
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
new_out = xtensor_from_tensor(out_tensor, out.type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter(tracks=[XCumReduce])
def lower_cumreduce(fgraph, node):
[x] = node.inputs
x_dims = x.type.dims
reduce_dims = node.op.dims
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]

if not reduce_axis:
return [x]


match node.op.binary_op:
case ps.add:
tensor_op_class = partial(CumOp, mode="add")
case ps.mul:
tensor_op_class = partial(CumOp, mode="mul")
case _:

# We don't know how to convert an arbitrary binary cum/reduce Op
return None


# Each dim corresponds to an application of Cumsum/Cumprod
out_tensor = tensor_from_xtensor(x)
for axis in reduce_axis:
out_tensor = tensor_op_class(axis=axis)(out_tensor)
out = xtensor_from_tensor(out_tensor, x.type.dims)
return [out]
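
Both rewrites only translate dim names into positional axes; the actual reductions are the usual tensor ones. A rough numpy analogue of what `lower_reduce` and `lower_cumreduce` build (illustrative only, not part of the diff):

import numpy as np

x_dims = ("a", "b", "c")
x_val = np.arange(24.0).reshape(2, 3, 4)

# XReduce over dims ("c", "a") with binary_op=add lowers to a Sum over those axes
reduce_dims = ("c", "a")
reduce_axis = tuple(x_dims.index(d) for d in reduce_dims)
assert np.allclose(x_val.sum(axis=reduce_axis), x_val.sum(axis=2).sum(axis=0))

# XCumReduce applies one cumulative Op per reduced dim, chained axis by axis
out = x_val
for axis in reduce_axis:
    out = np.cumsum(out, axis=axis)
assert out.shape == x_val.shape  # cumulative reductions keep the full shape
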
166 changes: 166 additions & 0 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
expand_dims,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
UnStack,
)


@register_lower_xtensor
@node_rewriter(tracks=[Stack])
def lower_stack(fgraph, node):
[x] = node.inputs
batch_ndim = x.type.ndim - len(node.op.stacked_dims)
stacked_axes = [
i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims
]
end = tuple(range(-len(stacked_axes), 0))

x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end)
if batch_ndim == (x.type.ndim - 1):
# This happens when we stack a "single" dimension, in this case all we need is the transpose
# Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename
final_tensor = x_tensor_transposed
else:
final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1)
final_tensor = x_tensor_transposed.reshape(final_shape)

new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter(tracks=[UnStack])
def lower_unstack(fgraph, node):
x = node.inputs[0]
unstacked_lengths = node.inputs[1:]
axis_to_unstack = x.type.dims.index(node.op.old_dim_name)

x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1])
final_tensor = x_tensor_transposed.reshape(
(*x_tensor_transposed.shape[:-1], *unstacked_lengths)
)
# Reintroduce any static shape information that was lost during the reshape
final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape)

new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
out_dims = node.outputs[0].type.dims
concat_dim = node.op.dim
concat_axis = out_dims.index(concat_dim)

# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)

# Broadcast non-concatenated dimensions of each input
non_concat_shape = [None] * len(out_dims)
for tensor_inp in tensor_inputs:
# TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
# I'm running this as "shape_unsafe" to simplify the logic / returned graph
for i, (bcast, sh) in enumerate(
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
):
if bcast or i == concat_axis or non_concat_shape[i] is not None:
continue
non_concat_shape[i] = sh

assert non_concat_shape.count(None) == 1

bcast_tensor_inputs = []
for tensor_inp in tensor_inputs:
# We modify the concat_axis in place, as we don't need the list anywhere else
non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]


@register_lower_xtensor
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
[x] = node.inputs
# Use the final dimensions that were already computed in make_node
out_dims = node.outputs[0].type.dims
in_dims = x.type.dims

# Compute the permutation based on the final dimensions
perm = tuple(in_dims.index(d) for d in out_dims)
x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([Squeeze])
def lower_squeeze(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
x_dims = x.type.dims
dims_to_remove = node.op.dims
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove)
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze)

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
"""Rewrite ExpandDims using tensor operations."""
x, size = node.inputs
out = node.outputs[0]

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
size_tensor = tensor_from_xtensor(size)

# Get the new dimension name and position
new_axis = 0 # Always insert at front

# Use tensor operations
if out.type.shape[0] == 1:
# Simple case: just expand with size 1
result_tensor = expand_dims(x_tensor, new_axis)

else:
# Otherwise broadcast to the requested size
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))

# Preserve static shape information
result_tensor = specify_shape(result_tensor, out.type.shape)

# Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result]
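
The heart of `lower_stack` is a moveaxis that gathers the stacked dims at the end, followed by a reshape that ravels them into one trailing axis. A standalone numpy sketch of that behavior (illustrative only, not part of the diff):

import numpy as np

x_dims = ("a", "b", "c")
x_val = np.arange(24.0).reshape(2, 3, 4)

# Stack dims ("a", "c") into a single new trailing dimension
stacked_dims = {"a", "c"}
stacked_axes = [i for i, d in enumerate(x_dims) if d in stacked_dims]
batch_ndim = x_val.ndim - len(stacked_axes)

transposed = np.moveaxis(x_val, stacked_axes, range(-len(stacked_axes), 0))
stacked = transposed.reshape(*transposed.shape[:batch_ndim], -1)
assert stacked.shape == (3, 8)  # ("b", "a*c"), matching the lowered graph
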
41 changes: 41 additions & 0 deletions pytensor/xtensor/rewriting/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase


lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)

optdb.register(
"lower_xtensor",
lower_xtensor_db,
"fast_run",
"fast_compile",
"minimum_compile",
position=0.1,
)


def register_lower_xtensor(
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
if isinstance(node_rewriter, str):

def register(inner_rewriter: RewriteDatabase | NodeRewriter):
return register_lower_xtensor(

inner_rewriter, node_rewriter, *tags, **kwargs
)

return register


else:
name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore
lower_xtensor_db.register(
name,
node_rewriter,
"fast_run",
"fast_compile",
"minimum_compile",
*tags,
**kwargs,
)
return node_rewriter
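
`register_lower_xtensor` follows the same pattern as the other `register_*` helpers in PyTensor: it can be applied directly as a decorator, or called with a string to attach an extra tag. A hypothetical usage sketch (the Op and rewrite below are placeholders, not part of the diff):

from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


class MyXOp(XOp):  # placeholder Op used only for illustration
    ...


@register_lower_xtensor
@node_rewriter(tracks=[MyXOp])
def lower_my_xop(fgraph, node):
    # Return a list of replacement tensor-based outputs, or None if the rewrite does not apply
    return None
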
123 changes: 123 additions & 0 deletions pytensor/xtensor/rewriting/vectorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import compute_batch_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise


@register_lower_xtensor
@node_rewriter(tracks=[XElemwise])
def lower_elemwise(fgraph, node):
out_dims = node.outputs[0].type.dims

# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)

tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
*tensor_inputs, return_list=True
)

# Convert output Tensors to XTensors
new_outs = [
xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs
]
return new_outs


@register_lower_xtensor
@node_rewriter(tracks=[XBlockwise])
def lower_blockwise(fgraph, node):
op: XBlockwise = node.op
batch_ndim = node.outputs[0].type.ndim - len(op.core_dims[1][0])
batch_dims = node.outputs[0].type.dims[:batch_ndim]

# Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
tensor_inputs = []
for inp, core_dims in zip(node.inputs, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_inputs.append(tensor_inp)

signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
if signature is None:
# Build a signature based on the core dimensions
# The Op signature could be more strict, as core_dims will never be repeated, but no functionality depends greatly on it
inputs_core_dims, outputs_core_dims = op.core_dims
inputs_signature = ",".join(

f"({', '.join(inp_core_dims)})" for inp_core_dims in inputs_core_dims
)
outputs_signature = ",".join(

f"({', '.join(out_core_dims)})" for out_core_dims in outputs_core_dims
)
signature = f"{inputs_signature}->{outputs_signature}"

tensor_op = Blockwise(core_op=op.core_op, signature=signature)
tensor_outs = tensor_op(*tensor_inputs, return_list=True)

# Convert output Tensors to XTensors
new_outs = [
xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
]
return new_outs


@register_lower_xtensor
@node_rewriter(tracks=[XRV])
def lower_rv(fgraph, node):
op: XRV = node.op
core_op = op.core_op

_, old_out = node.outputs
rng, *extra_dim_lengths_and_params = node.inputs
extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
params = extra_dim_lengths_and_params[len(op.extra_dims) :]

batch_ndim = old_out.type.ndim - len(op.core_dims[1])
param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

# Convert param XTensors to Tensors, align batch dimensions and place core dimensions at the end
tensor_params = []
for inp, core_dims in zip(params, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in param_batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_params.append(tensor_inp)

size = None
if op.extra_dims:
# RV size contains the lengths of all batch dimensions, including those coming from the parameters
if tensor_params:
param_batch_shape = tuple(
compute_batch_shape(tensor_params, ndims_params=core_op.ndims_params)
)
else:
param_batch_shape = ()
size = [*extra_dim_lengths, *param_batch_shape]

# RVs are their own core Op
new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs

# Convert output Tensors to XTensors
new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
return [new_next_rng, new_out]
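
`lower_elemwise` only has to align each input's dims with the output dims, inserting broadcastable axes for missing dims, before delegating to a regular `Elemwise`. The xarray-style alignment it reproduces can be sketched in plain numpy (illustrative only, not part of the diff):

import numpy as np

out_dims = ("city", "country", "planet")
inp_dims = ("planet", "city")
inp_val = np.arange(8.0).reshape(4, 2)  # ("planet", "city")

# Equivalent of the dimshuffle order: input axis for each output dim, or a new length-1 axis
order = [inp_dims.index(d) if d in inp_dims else None for d in out_dims]
aligned = inp_val.transpose([a for a in order if a is not None])
for axis, a in enumerate(order):
    if a is None:
        aligned = np.expand_dims(aligned, axis)
assert aligned.shape == (2, 1, 4)  # ("city", 1, "planet"), ready for broadcasting
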
500 changes: 500 additions & 0 deletions pytensor/xtensor/shape.py

Large diffs are not rendered by default.

866 changes: 866 additions & 0 deletions pytensor/xtensor/type.py

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions pytensor/xtensor/vectorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
from itertools import chain

import numpy as np

from pytensor import scalar as ps
from pytensor import shared
from pytensor.graph import Apply, Op
from pytensor.scalar import discrete_dtypes
from pytensor.tensor import tensor
from pytensor.tensor.random.op import RNGConsumerOp
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.utils import (
get_static_shape_from_size_variables,
)
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor


def combine_dims_and_shape(inputs):
dims_and_shape: dict[str, int | None] = {}
for inp in inputs:
for dim, dim_length in zip(inp.type.dims, inp.type.shape):
if dim not in dims_and_shape:
dims_and_shape[dim] = dim_length
elif dim_length is not None:
# Check for conflicting shapes
if (dims_and_shape[dim] is not None) and (
dims_and_shape[dim] != dim_length
):
raise ValueError(f"Dimension {dim} has conflicting shapes")

# Keep the non-None shape
dims_and_shape[dim] = dim_length
return dims_and_shape


class XElemwise(XOp):
__props__ = ("scalar_op",)

def __init__(self, scalar_op):
super().__init__()
self.scalar_op = scalar_op

def make_node(self, *inputs):
inputs = [as_xtensor(inp) for inp in inputs]
if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin):
raise ValueError(

f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
)

dims_and_shape = combine_dims_and_shape(inputs)
if dims_and_shape:
output_dims, output_shape = zip(*dims_and_shape.items())
else:
output_dims, output_shape = (), ()

dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
output_dtypes = [
out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs
]
outputs = [
xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape)
for output_dtype in output_dtypes
]
return Apply(self, inputs, outputs)


class XBlockwise(XOp):
__props__ = ("core_op", "core_dims")

def __init__(
self,
core_op: Op,
core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]],
signature: str | None = None,
):
super().__init__()
self.core_op = core_op
self.core_dims = core_dims
self.signature = signature # Only used for lowering, not for validation

def make_node(self, *inputs):
inputs = [as_xtensor(i) for i in inputs]
if len(inputs) != len(self.core_dims[0]):
raise ValueError(

f"Wrong number of inputs, expected {len(self.core_dims[0])}, got {len(inputs)}"
)

dims_and_shape = combine_dims_and_shape(inputs)

core_inputs_dims, core_outputs_dims = self.core_dims
core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
batch_dims, batch_shape = zip(
*((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set)
)

dummy_core_inputs = []
for inp, core_inp_dims in zip(inputs, core_inputs_dims):
try:
core_static_shape = [
inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims
]
except IndexError:
raise ValueError(

f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}"
)
dummy_core_inputs.append(
tensor(dtype=inp.type.dtype, shape=core_static_shape)
)
core_node = self.core_op.make_node(*dummy_core_inputs)

outputs = [
xtensor(
dtype=core_out.type.dtype,
shape=batch_shape + core_out.type.shape,
dims=batch_dims + core_out_dims,
)
for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
]
return Apply(self, inputs, outputs)


class XRV(XOp, RNGConsumerOp):
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
Xarray does not offer random generators, so this class implements a new API.
It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and
enforces dim-based broadcasting between inputs and output.
It differs from XBlockwise in a couple of ways:
1. It is restricted to one sample output
2. It takes a random generator as the first input and returns the consumed generator as the first output.
3. It has the concept of extra dimensions, which determine extra batch dimensions of the output, that are not
implied by batch dimensions of the parameters.
"""

default_output = 1
__props__ = ("core_op", "core_dims", "extra_dims")

def __init__(
self,
core_op,
core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
extra_dims: tuple[str, ...],
):
super().__init__()
self.core_op = core_op
inps_core_dims, out_core_dims = core_dims
for operand_dims in (*inps_core_dims, out_core_dims):
if len(set(operand_dims)) != len(operand_dims):
raise ValueError(f"Operand has repeated dims {operand_dims}")
self.core_dims = (tuple(i for i in inps_core_dims), tuple(out_core_dims))
if len(set(extra_dims)) != len(extra_dims):
raise ValueError("size_dims must be unique")

self.extra_dims = tuple(extra_dims)

def update(self, node):
# RNG input and update are the first input and output respectively
return {node.inputs[0]: node.outputs[0]}


def make_node(self, rng, *extra_dim_lengths_and_params):
if rng is None:
rng = shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(

"The type of rng should be an instance of RandomGeneratorType "
)

extra_dim_lengths = [
as_xtensor(dim_length).values
for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
]
if not all(
(dim_length.type.ndim == 0 and dim_length.type.dtype in discrete_dtypes)
for dim_length in extra_dim_lengths
):
raise TypeError("All dimension lengths should be scalar discrete dtype.")


params = [
as_xtensor(param)
for param in extra_dim_lengths_and_params[len(self.extra_dims) :]
]
if len(params) != len(self.core_op.ndims_params):
raise ValueError(

f"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, "
f"got {len(extra_dim_lengths_and_params)}"
)

param_core_dims, output_core_dims = self.core_dims
input_core_dims_set = set(chain.from_iterable(param_core_dims))

# Check parameters don't have core dimensions they shouldn't have
for param, core_param_dims in zip(params, param_core_dims):
if invalid_core_dims := (
set(param.type.dims) - set(core_param_dims)
).intersection(input_core_dims_set):
raise ValueError(
f"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}"
)

extra_dims_and_shape = dict(
zip(
self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths)
)
)
params_dims_and_shape = combine_dims_and_shape(params)

# Check that no parameter dims conflict with size dims
if conflict_dims := set(extra_dims_and_shape).intersection(
params_dims_and_shape
):
raise ValueError(
f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
)

batch_dims_and_shape = [
(dim, dim_length)
for dim, dim_length in (
extra_dims_and_shape | params_dims_and_shape
).items()
if dim not in input_core_dims_set
]
if batch_dims_and_shape:
batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape)
else:
batch_output_dims, batch_output_shape = (), ()

dummy_core_inputs = []
for param, core_param_dims in zip(params, param_core_dims):
try:
core_static_shape = [
param.type.shape[param.type.dims.index(d)] for d in core_param_dims
]
except ValueError:
raise ValueError(
f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}"
)
dummy_core_inputs.append(
tensor(dtype=param.type.dtype, shape=core_static_shape)
)
core_node = self.core_op.make_node(rng, None, *dummy_core_inputs)

if not len(core_node.outputs) == 2:
raise NotImplementedError(

"XRandomVariable only supports core ops with two outputs (rng, out)"
)

_, core_out = core_node.outputs
out = xtensor(
dtype=core_out.type.dtype,
shape=batch_output_shape + core_out.type.shape,
dims=batch_output_dims + output_core_dims,
)

return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
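
`combine_dims_and_shape` merges dims across inputs in first-seen order, keeping the first known (non-None) length for each dim and raising if two known lengths disagree. A minimal sketch, assuming `pytensor.xtensor` is importable in your environment (not part of the diff):

from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape

x = xtensor("x", dims=("a", "b"), shape=(2, None))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))

# "b" is unknown in x but known in y, so the known length wins
assert combine_dims_and_shape([x, y]) == {"a": 2, "b": 3, "c": 4}
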
2 changes: 1 addition & 1 deletion tests/tensor/random/rewriting/test_basic.py
Original file line number Diff line number Diff line change
@@ -950,7 +950,7 @@ def test_Dimshuffle_lift_restrictions():
1e-7,
),
(
(0, 1, 2),
(0, 2, 1),
True,
normal,
(np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
2 changes: 1 addition & 1 deletion tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
@@ -148,7 +148,7 @@ def test_recursive_lift(self):

def test_useless_dimshuffle(self):
x, *_ = inputs()
e = ds(x, (0, 1))
e = DimShuffle(new_order=(0, 1), input_ndim=2)(x)
g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
Empty file added tests/xtensor/__init__.py
Empty file.
512 changes: 512 additions & 0 deletions tests/xtensor/test_indexing.py

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions tests/xtensor/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")
pytest.importorskip("xarray_einstats")

import numpy as np
from xarray import DataArray
from xarray_einstats.linalg import (
cholesky as xr_cholesky,
)
from xarray_einstats.linalg import (
solve as xr_solve,
)

from pytensor.xtensor.linalg import cholesky, solve
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function


def test_cholesky():
x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4))
y = cholesky(x, dims=["b", "a"])
assert y.type.dims == ("batch", "b", "a")
assert y.type.shape == (3, 4, 4)

fn = xr_function([x], y)
rng = np.random.default_rng(25)
x_ = rng.random(size=(3, 4, 4))
x_ = x_ @ x_.mT
x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims)
xr_assert_allclose(
fn(x_test),
xr_cholesky(x_test, dims=["b", "a"]),
)


def test_solve_vector_b():
a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
b = xtensor("b", dims=("city", "planet"), shape=(None, 2))
x = solve(a, b, dims=["country", "city"])
assert x.type.dims == ("galaxy", "planet", "country")
# Core Solve doesn't make use of the fact A must be square in the static shape
assert x.type.shape == (1, 2, None)

fn = xr_function([a, b], x)

rng = np.random.default_rng(25)
a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims)

xr_assert_allclose(
fn(a_test, b_test),
xr_solve(a_test, b_test, dims=["country", "city"]),
)


def test_solve_matrix_b():
a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2))
x = solve(a, b, dims=["country", "city", "district"])
assert x.type.dims == ("galaxy", "planet", "country", "district")
# Core Solve doesn't make use of the fact A must be square in the static shape
assert x.type.shape == (1, 2, None, 5)

fn = xr_function([a, b], x)

rng = np.random.default_rng(25)
a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims)

xr_assert_allclose(
fn(a_test, b_test),
xr_solve(a_test, b_test, dims=["country", "city", "district"]),
)
316 changes: 316 additions & 0 deletions tests/xtensor/test_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")

import inspect

import numpy as np
from xarray import DataArray

import pytensor.scalar as ps
import pytensor.xtensor.math as pxm
from pytensor import function
from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function


def test_all_scalar_ops_are_wrapped():
# This ignores wrapper functions
pxm_members = {name for name, _ in inspect.getmembers(pxm)}
for name, op in inspect.getmembers(ps):
if name in {
"complex_from_polar",
"inclosedrange",
"inopenrange",
"round_half_away_from_zero",
"round_half_to_even",
"scalar_abs",
"scalar_maximum",
"scalar_minimum",
} or name.startswith("convert_to_"):
# These are not regular numpy functions or are unusual aliases
continue
if isinstance(op, ScalarOp) and name not in pxm_members:
raise NotImplementedError(f"ScalarOp {name} not wrapped in xtensor.math")


def test_scalar_case():
x = xtensor("x", dims=(), shape=())
y = xtensor("y", dims=(), shape=())
out = add(x, y)

fn = function([x, y], out)

x_test = DataArray(2.0, dims=())
y_test = DataArray(3.0, dims=())
np.testing.assert_allclose(fn(x_test.values, y_test.values), 5.0)


def test_dimension_alignment():
x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4))
y = xtensor(
"y",
dims=("galaxy", "country", "city"),
shape=(5, 3, 2),
)
z = xtensor("z", dims=("universe",), shape=(1,))
out = add(x, y, z)
assert out.type.dims == ("city", "country", "planet", "galaxy", "universe")

fn = function([x, y, z], out)

rng = np.random.default_rng(41)
test_x, test_y, test_z = (
DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims)
for inp in [x, y, z]
)
np.testing.assert_allclose(
fn(test_x.values, test_y.values, test_z.values),
(test_x + test_y + test_z).values,
)


def test_renamed_dimension_alignment():
x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3))
y = rename(x, b1="b2", b2="b1")
z = rename(x, b2="b3")
assert y.type.dims == ("a", "b2", "b1")
assert z.type.dims == ("a", "b1", "b3")

out1 = add(x, x) # self addition
assert out1.type.dims == ("a", "b1", "b2")
out2 = add(x, y) # transposed addition
assert out2.type.dims == ("a", "b1", "b2")
out3 = add(x, z) # outer addition
assert out3.type.dims == ("a", "b1", "b2", "b3")

fn = xr_function([x], [out1, out2, out3])
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
results = fn(x_test)
expected_results = [
x_test + x_test,
x_test + x_test.rename(b1="b2", b2="b1"),
x_test + x_test.rename(b2="b3"),
]
for result, expected_result in zip(results, expected_results):
xr_assert_allclose(result, expected_result)


def test_chained_operations():
x = xtensor("x", dims=("city",), shape=(None,))
y = xtensor("y", dims=("country",), shape=(4,))
z = add(exp(x), exp(y))
assert z.type.dims == ("city", "country")
assert z.type.shape == (None, 4)

fn = function([x, y], z)

x_test = DataArray(np.zeros(3), dims="city")
y_test = DataArray(np.ones(4), dims="country")

np.testing.assert_allclose(
fn(x_test.values, y_test.values),
(np.exp(x_test) + np.exp(y_test)).values,
)


def test_multiple_constant():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
out = exp(x * 2) + 2

fn = function([x], out)

x_test = np.zeros((2, 3), dtype=x.type.dtype)
res = fn(x_test)
expected_res = np.exp(x_test * 2) + 2
np.testing.assert_allclose(res, expected_res)


def test_cast():
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32")
yf64 = x.astype("float64")
yi16 = x.astype("int16")
ybool = x.astype("bool")

fn = xr_function([x], [yf64, yi16, ybool])
x_test = xr_arange_like(x)
res_f64, res_i16, res_bool = fn(x_test)
xr_assert_allclose(res_f64, x_test.astype("float64"))
xr_assert_allclose(res_i16, x_test.astype("int16"))
xr_assert_allclose(res_bool, x_test.astype("bool"))

yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")


def test_dot():
"""Test basic dot product operations."""
# Test matrix-vector dot product (with multiple-letter dim names)
x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
y = xtensor("y", dims=("bb",), shape=(3,))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
y_test = DataArray(np.ones(3), dims=("bb",))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-vector dot product with ellipsis
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with string dim
z = x.dot(y, dim="b")
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim="b")
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with list of dims
z = x.dot(y, dim=["b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b"])
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with ellipsis
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test a case where there are two dimensions to sum over
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Same but with explicit dimensions
z = x.dot(y, dim=["b", "c"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b", "c"])
xr_assert_allclose(z_test, expected)

# Same but with ellipses
z = x.dot(y, dim=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Dot product with sum
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
expected = x_test.dot(y_test, dim=("a", "b", "c"))

x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
z = x.dot(y, dim=("a", "b", "c"))
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Dot product with sum in the middle
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
expected = x_test.dot(y_test, dim=("b", "d"))
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
z = x.dot(y, dim=("b", "d"))
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with first two dims
expected = x_test.dot(y_test, dim=["a", "b"])
z = x.dot(y, dim=["a", "b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with last two
expected = x_test.dot(y_test, dim=["d", "e"])
z = x.dot(y, dim=["d", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with every other dim
expected = x_test.dot(y_test, dim=["a", "c", "e"])
z = x.dot(y, dim=["a", "c", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Test symbolic shapes
x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic
y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic
z = x.dot(y)
fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)


def test_dot_errors():
# No matching dimensions
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
with pytest.raises(ValueError, match="Dimension e not found in either input"):
x.dot(y, dim="e")

# Concrete dimension size mismatches
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
with pytest.raises(
ValueError,
match="Size of dim 'b' does not match",
):
x.dot(y)

# Symbolic dimension size mismatches
x = xtensor("x", dims=("a", "b"), shape=(2, None))
y = xtensor("y", dims=("b", "c"), shape=(None, 5))
z = x.dot(y)
fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
# Doesn't fail until the rewrite
with pytest.raises(ValueError, match="not aligned"):
fn(x_test, y_test)
422 changes: 422 additions & 0 deletions tests/xtensor/test_random.py

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions tests/xtensor/test_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")

from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function


@pytest.mark.parametrize(
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
)
@pytest.mark.parametrize(
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
)
def test_reduction(method, dim):
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
out = getattr(x, method)(dim=dim)

fn = xr_function([x], out)
x_test = xr_arange_like(x)

xr_assert_allclose(
fn(x_test),
getattr(x_test, method)(dim=dim),
)
468 changes: 468 additions & 0 deletions tests/xtensor/test_shape.py

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions tests/xtensor/test_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")

import numpy as np
from xarray import DataArray

from pytensor.graph.basic import equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorType, as_xtensor


def test_xtensortype():
x1 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
x2 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
x3 = XTensorType(dtype="float64", dims=("a", "b"), shape=(None, 3))
y1 = XTensorType(dtype="float64", dims=("c", "d"), shape=(4, 5))
z1 = XTensorType(dtype="float32", dims=("a", "b"), shape=(2, 3))

assert x1 == x2 and x1.is_super(x2) and x2.is_super(x1)
assert x1 != x3 and not x1.is_super(x3) and x3.is_super(x1)
assert x1 != y1 and not x1.is_super(y1) and not y1.is_super(x1)
assert x1 != z1 and not x1.is_super(z1) and not z1.is_super(x1)


def test_xtensortype_filter_variable():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))

y1 = xtensor("y1", dims=("a", "b"), shape=(2, 3))
assert x.type.filter_variable(y1) is y1

y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
expected_y2 = y2.transpose()
assert equal_computations([x.type.filter_variable(y2)], [expected_y2])

y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
expected_y3 = as_xtensor(
specify_shape(y3.transpose().values, (2, 3)), dims=("a", "b")
)
assert equal_computations([x.type.filter_variable(y3)], [expected_y3])

# Cases that fail
with pytest.raises(TypeError):
y4 = xtensor("y4", dims=("a", "b"), shape=(3, 2))
x.type.filter_variable(y4)

with pytest.raises(TypeError):
y5 = xtensor("y5", dims=("a", "c"), shape=(2, 3))
x.type.filter_variable(y5)

with pytest.raises(TypeError):
y6 = xtensor("y6", dims=("a", "b", "c"), shape=(2, 3, 4))
x.type.filter_variable(y6)

with pytest.raises(TypeError):
y7 = xtensor("y7", dims=("a", "b"), shape=(2, 3), dtype="int32")
x.type.filter_variable(y7)

z1 = tensor("z1", shape=(2, None))
expected_z1 = as_xtensor(specify_shape(z1, (2, 3)), dims=("a", "b"))
assert equal_computations([x.type.filter_variable(z1)], [expected_z1])

# Cases that fail
with pytest.raises(TypeError):
z2 = tensor("z2", shape=(3, 2))
x.type.filter_variable(z2)

with pytest.raises(TypeError):
z3 = tensor("z3", shape=(1, 2, 3))
x.type.filter_variable(z3)

with pytest.raises(TypeError):
z4 = tensor("z4", shape=(2, 3), dtype="int32")
x.type.filter_variable(z4)


def test_xtensor_constant():
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))

y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
assert y.type == x.type
assert x.signature() == y.signature()
assert x.equals(y)
x_eval = x.eval()
assert isinstance(x.eval(), np.ndarray)
np.testing.assert_array_equal(x_eval, y.eval(), strict=True)

z = as_xtensor(np.ones((3, 2)), dims=("b", "a"))
assert z.type != x.type
assert z.signature() != x.signature()
assert not x.equals(z)
np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)


def test_as_tensor():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))

with pytest.raises(
TypeError,
match="PyTensor forbids automatic conversion of XTensorVariable to TensorVariable",
):
as_tensor(x)

x_pt = as_tensor(x, allow_xtensor_conversion=True)
assert equal_computations([x_pt], [x.values])


def test_minimum_compile():
from pytensor.compile.mode import Mode

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = x.transpose()
minimum_mode = Mode(linker="py", optimizer="minimum_compile")
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
np.testing.assert_array_equal(result, np.ones((3, 2)))
60 changes: 60 additions & 0 deletions tests/xtensor/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ruff: noqa: E402
import pytest


pytest.importorskip("xarray")

import numpy as np
from xarray import DataArray
from xarray.testing import assert_allclose

from pytensor import function
from pytensor.xtensor.type import XTensorType


def xr_function(*args, **kwargs):
"""Compile and wrap a PyTensor function to return xarray DataArrays."""
fn = function(*args, **kwargs)
symbolic_outputs = fn.maker.fgraph.outputs
assert all(
isinstance(out.type, XTensorType) for out in symbolic_outputs
), "All outputs must be xtensor"

def xfn(*xr_inputs):
np_inputs = [
inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs
]
np_outputs = fn(*np_inputs)
if not isinstance(np_outputs, tuple | list):
return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims)
else:
return tuple(
DataArray(res, dims=out.type.dims)
for res, out in zip(np_outputs, symbolic_outputs)
)

xfn.fn = fn
return xfn


def xr_assert_allclose(x, y, *args, **kwargs):
# Assert that two xarray DataArrays are close, ignoring coordinates
x = x.drop_vars(x.coords)
y = y.drop_vars(y.coords)
assert_allclose(x, y, *args, **kwargs)


def xr_arange_like(x):
return DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)


def xr_random_like(x, rng=None):
if rng is None:
rng = np.random.default_rng()

return DataArray(
rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims
)
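
A minimal usage sketch of these helpers, mirroring how the test modules above use them (assumes xarray is installed):

import numpy as np
from xarray import DataArray

from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
out = x.transpose()

fn = xr_function([x], out)
x_test = DataArray(np.arange(6.0).reshape(2, 3), dims=("a", "b"))
xr_assert_allclose(fn(x_test), x_test.transpose())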