diff --git a/doc/internal/named-dims.ipynb b/doc/internal/named-dims.ipynb new file mode 100644 index 0000000000..2d2027cc79 --- /dev/null +++ b/doc/internal/named-dims.ipynb @@ -0,0 +1,412 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8e7ecdb3-7df3-49a7-bab5-f63455d43581", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4319db8d-a559-473d-a298-6b6df67db50b", + "metadata": {}, + "outputs": [], + "source": [ + "class Subset:\n", + " pass\n", + "\n", + "\n", + "class Slice(Subset):\n", + " slice: slice\n", + "\n", + "\n", + "class IndexSet(Subset):\n", + " values: [int]\n", + "\n", + "\n", + "class Dynamic(Subset):\n", + " pass\n", + "\n", + "\n", + "class SubsetType:\n", + " base: DimType\n", + " subset: list[Subset]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea87cffe-d576-4b0f-8ddc-f26593a361e6", + "metadata": {}, + "outputs": [], + "source": [ + "x_sub, x_sub2 = px.project_to_shared_dim(x[:-1], x[1:], dim=foo)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5093c6c2-2e5d-4215-8d6a-c519f800d93b", + "metadata": {}, + "outputs": [], + "source": [ + "foo = px.dim()\n", + "x = px.xtensor(foo)\n", + "\n", + "# Example 1\n", + "x[:-1].pad_to_dim(foo, fill_value=0.0)\n", + "\n", + "# example 2\n", + "left_part, right_part = foo.intersect_and_align(x[:-1], x[1:])\n", + "left_part + right_part\n", + "\n", + "# example 3\n", + "x[1:] + x[:-1].with_dims_like(x[1:])" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "8b73d0c8-69f8-4752-8d9c-ca758dfa82a1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/adr/git/pytensor/pytensor/xtensor/__init__.py:18: UserWarning: xtensor module is experimental and full of bugs\n", + " warnings.warn(\"xtensor module is experimental and full of bugs\")\n" + ] + } + ], + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import pytensor.xtensor as px\n", + "import pytest" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "b62c5cd7-9e4d-4ac3-9f71-23dec2d7ae89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BasicDim(foo, uuid=?)" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "foo.type" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1de5f25f-d09a-48ca-b49e-53b0649a76f2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array(5.)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "x = px.ones(foo, name=\"x\")\n", + "func = pytensor.function([foo], [x.sum(foo)], mode=\"FAST_RUN\")\n", + "func(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "be07b1b5-2b49-480f-a66b-4d6bf426b1da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array(0.)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "x = px.ones(foo, name=\"x\")\n", + "func = pytensor.function([foo], [x.std(foo)], mode=\"FAST_RUN\")\n", + "func(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "abf9055e-c0be-4bb8-aec0-f8f723f387ae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "FromLength{dim_type=BasicDim(bar, uuid=?)} [id A] 'bar'\n", + " └─ TensorFromScalar [id B]\n", + " └─ [id C]\n" + ] + } + ], + "source": [ + "length = pytensor.scalar.basic.int64()\n", + "bar = px.dim(\"bar\", size=length)\n", + "bar.dprint();" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ca970d40-fd30-49a7-a1d2-a2a32281b43d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clone{dim_type=CloneDim(bar2, base=BasicDim(bar, uuid=?), uuid=?)} [id A] 'bar2'\n", + " └─ bar [id B]\n" + ] + } + ], + "source": [ + "bar = px.dim(\"bar\")\n", + "bar2 = px.dim(\"bar2\", size=bar)\n", + "# same as bar.clone_dim(\"bar2\")\n", + "bar2.dprint();" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a6f00cd4-d4fe-4889-a33b-8c7a2e6c582a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(foo, bar)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "bar = foo.clone_dim(\"bar\")\n", + "x = px.xtensor(\"x\", dims=[foo, bar])\n", + "x.dims" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2ae2b959-f64e-4739-bab5-61840a333471", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]])]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "bar = px.dim(\"bar\")\n", + "z1 = px.ones(foo)\n", + "z2 = px.ones(bar)\n", + "func = pytensor.function([foo, bar], [z1 + z2])\n", + "func(3, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "23bb0380-5fcf-4020-97f9-073edfa72719", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b09ae469-ab34-402c-b13c-8f52cf140249", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "XTensorFromTensor [id A]\n", + " ├─ SpecifyShape [id B]\n", + " │ ├─ [id C]\n", + " │ ├─ Length [id D]\n", + " │ │ └─ foo [id E]\n", + " │ └─ Length [id F]\n", + " │ └─ bar [id G]\n", + " ├─ foo [id E]\n", + " └─ bar [id G]\n" + ] + }, + { + "data": { + "text/plain": [ + "[array([[0., 0., 0., 0.],\n", + " [0., 0., 0., 0.],\n", + " [0., 0., 0., 0.]])]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo = px.dim(\"foo\")\n", + "bar = px.dim(\"bar\")\n", + "z_tensor = pt.matrix()\n", + "z_new = px.xtensor_from_tensor(z_tensor, [foo, bar], check=True)\n", + "z_new.dprint()\n", + "func = pytensor.function([z_tensor, foo, bar], [z_new])\n", + "with pytest.raises(AssertionError):\n", + " func(np.zeros((3, 4)), 3, 5)\n", + "func(np.zeros((3, 4)), 3, 4)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "80c101fa-1c36-4cbf-a140-89d9136ad879", + "metadata": {}, + "outputs": [], + "source": [ + "foo = px.dim(\"foo\")\n", + "with pytest.raises(ValueError):\n", + " px.xtensor(\"z\", dims=[foo, foo])" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "id": "cde39455-8262-4e00-be37-903c654fc0a2", + "metadata": {}, + "outputs": [], + "source": [ + "def tensorize(inputs, outputs, *, check=True):\n", + " dims = {}\n", + " for input in inputs:\n", + " if isinstance(input, px.type.DimVariable):\n", + " dims[input.type] = input\n", + 
"\n", + " new_inputs = []\n", + " replacements = []\n", + " for input in inputs:\n", + " if isinstance(input, px.type.DimVariable):\n", + " replacements.append((input, input))\n", + " new_inputs.append(input)\n", + " else:\n", + " new_input = px.basic.tensor_from_xtensor(input).type()\n", + " replacement = px.xtensor_from_tensor(new_input, [dims[dim.type] for dim in input.dims], check=True)\n", + " replacements.append((input, replacement))\n", + " new_inputs.append(new_input)\n", + "\n", + " new_outputs = pytensor.clone_replace( outputs, replacements)\n", + " #new_inputs = [new_input for _, new_input in replacements]\n", + " return new_inputs, new_outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "id": "431d3bc3-ef15-469d-b6e3-744ef28537d4", + "metadata": {}, + "outputs": [], + "source": [ + "country = px.dim(\"country\")\n", + "treatment = px.dim(\"treatment\")\n", + "\n", + "effect = px.xtensor(\"effect\", dims=[country, treatment])\n", + "sigma = px.xtensor(\"sigma\", dims=[])\n", + "observed = px.xtensor(\"observed\", dims=[treatment, country])\n", + "\n", + "residual = ((effect - observed) / sigma) + (effect - observed).std()" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "id": "27b3e654-8b51-4b78-a978-fec63724cf4d", + "metadata": {}, + "outputs": [], + "source": [ + "inputs = [country, treatment, effect, sigma, observed]\n", + "outputs = [residual]\n", + "\n", + "tensor_inputs, tensor_outputs = tensorize(inputs=inputs, outputs=outputs, check=True)\n", + "func = pytensor.function(tensor_inputs, tensor_outputs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (Pixi)", + "language": "python", + "name": "pixi-kernel-python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 7f1b9ecddb..ba54c7ac44 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -2,10 +2,12 @@ import pytensor.xtensor.rewriting from pytensor.xtensor import linalg, random +from pytensor.xtensor.basic import ones, xtensor_from_tensor, zeros from pytensor.xtensor.math import dot from pytensor.xtensor.shape import concat from pytensor.xtensor.type import ( as_xtensor, + dim, xtensor, xtensor_constant, ) diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 5c1f700b9f..ae739a86e1 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -1,9 +1,14 @@ -from collections.abc import Sequence - from pytensor.compile.ops import TypeCastingOp from pytensor.graph import Apply, Op +from pytensor.scalar.basic import uint64 +from pytensor.tensor.basic import ones as tensor_ones +from pytensor.tensor.basic import zeros as tensor_zeros +from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import TensorType -from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor +from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor + + +DIM_LENGTH_SCALAR = uint64 class XOp(Op): @@ -32,6 +37,7 @@ def make_node(self, x): return Apply(self, [x], [output]) def L_op(self, inputs, outs, g_outs): + # TODO fix [x] = inputs [g_out] = g_outs return [xtensor_from_tensor(g_out, dims=x.type.dims)] @@ -41,46 +47,49 @@ def L_op(self, inputs, outs, 
g_outs):
 
 
 class XTensorFromTensor(XTypeCastOp):
-    __props__ = ("dims",)
-
-    def __init__(self, dims: Sequence[str]):
-        super().__init__()
-        self.dims = tuple(dims)
+    __props__ = ()
 
-    def make_node(self, x):
+    def make_node(self, x, *dims):
         if not isinstance(x.type, TensorType):
             raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
-        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
-        return Apply(self, [x], [output])
+        output = xtensor(dtype=x.type.dtype, dims=dims)
+        return Apply(self, [x, *dims], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]
 
 
-def xtensor_from_tensor(x, dims, name=None):
-    return XTensorFromTensor(dims=dims)(x, name=name)
+def xtensor_from_tensor(x, dims, name=None, check: bool = True):
+    if check:
+        x = specify_shape(x, [dim.size for dim in dims])
+    return XTensorFromTensor()(x, *dims, name=name)
 
 
-class Rename(XTypeCastOp):
-    __props__ = ("new_dims",)
+class MapDims(XTypeCastOp):
+    __props__ = ("new_dim_indices",)
 
-    def __init__(self, new_dims: tuple[str, ...]):
-        super().__init__()
-        self.new_dims = new_dims
+    def __init__(self, new_dim_indices: tuple[int, ...]):
+        self.new_dim_indices = new_dim_indices
 
-    def make_node(self, x):
+    def make_node(self, x, *new_dims):
         x = as_xtensor(x)
-        output = x.type.clone(dims=self.new_dims)()
+        out_dims = list(x.dims)
+        for i, idx in enumerate(self.new_dim_indices):
+            out_dims[idx] = new_dims[i]
+
+        output = x.type.clone(dims=out_dims)()
         return Apply(self, [x], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
-        return [rename(g_out, dims=x.type.dims)]
+        return [map_dims(g_out, dims=x.type.dims)]
 
 
-def rename(x, name_dict: dict[str, str] | None = None, **names: str):
+def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
     if name_dict is not None:
         if names:
             raise ValueError("Cannot use both positional and keyword names in rename")
@@ -97,4 +106,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
                 f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
             )
 
-    return Rename(tuple(new_names))(x)
+    return MapDims(tuple(new_names))(x)
+
+
+def zeros(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with zeros."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
+
+
+def ones(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with ones."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
diff --git a/pytensor/xtensor/dims.py b/pytensor/xtensor/dims.py
new file mode 100644
index 0000000000..3d0d4dd3ae
--- /dev/null
+++ b/pytensor/xtensor/dims.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from uuid import uuid4
+
+import numpy as np
+
+from pytensor.graph.basic import Apply
+from pytensor.graph.op import Op, Variable
+from pytensor.xtensor.type import (
+    DIM_LENGTH_TYPE,
+    DIM_LENGTH_VARIABLE,
+    BasicDim,
+    CloneDim,
+    DimType,
+    DimVariable,
+    ProductDim,
+    XTensorVariable,
+)
+
+
+class DimOp(Op):
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError(
+            f"xtensor operation {self} must be lowered to equivalent tensor operations"
+        )
+
+
+# Not a dim op,
because it doesn't return a DimVariable +class Length(Op): + __props__ = () + + def make_node(self, *inputs: Variable) -> Apply: + (x,) = inputs + if not isinstance(x, DimVariable): + raise TypeError(f"x must be a DimVariable, got {type(x.type)}") + return Apply(self, [x], [DIM_LENGTH_TYPE()]) + + def perform(self, node, inputs, outputs): + # outputs[0][0] = np.int64(inputs[0]) + outputs[0][0] = np.array(inputs[0], dtype=DIM_LENGTH_TYPE.dtype) + + +def _dim_size(dim: DimVariable) -> DIM_LENGTH_VARIABLE: + return Length()(dim) + + +class FromLength(DimOp): + __props__ = ("dim_type",) + + def __init__(self, dim_type: DimType): + super().__init__() + self.dim_type = dim_type + + def make_node(self, *inputs: Variable) -> Apply: + (length,) = inputs + if not isinstance(length, DIM_LENGTH_VARIABLE): + raise TypeError( + f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}" + ) + if length.type != DIM_LENGTH_TYPE: + raise TypeError( + f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}" + ) + return Apply(self, [length], [self.dim_type()]) + + def perform(self, node, inputs, outputs): + """Convert the length to a list of lengths.""" + outputs[0][0] = inputs[0] + + +def from_length(length: DIM_LENGTH_VARIABLE, name: str | None = None) -> DimVariable: + # TODO add check for dtype + if not isinstance(length, DIM_LENGTH_VARIABLE): + raise TypeError( + f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}" + ) + if length.type != DIM_LENGTH_TYPE: + raise TypeError( + f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}" + ) + + uuid = uuid4() + dim_type = BasicDim(uuid=uuid, name=name) + op = FromLength(dim_type) + return op(length, name=name) + + +class DimFromTensor(Op): + __props__ = ("dim_type",) + + def __init__(self, dim_type: DimType): + super().__init__() + self.dim_type = dim_type + + def make_node(self, *inputs: Variable) -> Apply: + (x,) = inputs + if not isinstance(x, XTensorVariable): + raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}") + return Apply(self, [x], [self.dim_type()]) + + def perform(self, node, inputs, outputs): + """Convert the tensor to a dimension variable.""" + (x,) = inputs + (x_var,) = node.inputs + for i, dim in enumerate(x_var.type.dims): + if dim == self.dim_type: + # outputs[0][0] = np.int64(x.shape[i]) + outputs[0][0] = np.array(x.shape[i], dtype=DIM_LENGTH_TYPE.dtype) + return + raise ValueError(f"Dimension {self.dim_type} not found in tensor {x.type.dims}") + + +def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable: + op = DimFromTensor(dim_type=x.type.dims[idx]) + return op(x, name=x.type.dims[idx].name) + + +class Clone(Op): + __props__ = ("dim_type",) + + def __init__(self, dim_type): + super().__init__() + self.dim_type = dim_type + + def make_node(self, *inputs: Variable) -> Apply: + (x,) = inputs + if not isinstance(x, DimVariable): + raise TypeError(f"x must be a DimVariable, got {type(x.type)}") + return Apply(self, [x], [self.dim_type()]) + + def perform(self, node, inputs, outputs): + outputs[0][0] = inputs[0] + + +def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable: + """Rename a dimension variable. + + Args: + name: The new name for the dimension. + + Returns: + A new DimVariable with the updated name. 
+ """ + dim_type = CloneDim(uuid=uuid4(), base=dim.type, name=name) + return Clone(dim_type)(dim, name=name) + + +class Product(Op): + __props__ = () + + def make_node(self, *dims: Variable) -> Apply: + if not all(isinstance(dim, DimVariable) for dim in dims): + raise TypeError("All inputs must be DimVariables.") + out = dim_type() + return Apply(self, list(dims), [out]) + + def perform(self, node, inputs, outputs): + outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_TYPE.dtype).item() + + +def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable: + return Product()(*dims, name=name) + + +def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable: + if not isinstance(dim, DimVariable): + raise TypeError(f"dim must be a DimVariable, got {type(dim)}") + + if not tensors: + raise ValueError("At least one tensor must be provided for rebasing.") + + for tensor in tensors: + for i, tensor_dim in enumerate(tensor.type.dims): + if dim.type == tensor_dim: + return _dim_from_tensor(tensor, idx=i) + raise ValueError(f"Dimension {dim.type} not found in any of the provided tensors.") diff --git a/pytensor/xtensor/reduction.py b/pytensor/xtensor/reduction.py index 300e480750..544fcefe42 100644 --- a/pytensor/xtensor/reduction.py +++ b/pytensor/xtensor/reduction.py @@ -9,20 +9,20 @@ from pytensor.xtensor.basic import XOp from pytensor.xtensor.math import neq, sqrt from pytensor.xtensor.math import sqr as square -from pytensor.xtensor.type import as_xtensor, xtensor +from pytensor.xtensor.type import DimType, DimVariable, as_xtensor, xtensor -REDUCE_DIM = str | Sequence[str] | EllipsisType | None +REDUCE_DIM = DimVariable | Sequence[DimVariable] | EllipsisType | None class XReduce(XOp): __slots__ = ("binary_op", "dims") - def __init__(self, binary_op, dims: Sequence[str]): + def __init__(self, binary_op, dims: Sequence[DimVariable]): super().__init__() self.binary_op = binary_op # Order of reduce dims doesn't change the behavior of the Op - self.dims = tuple(sorted(dims)) + self.dims = tuple(dims) def make_node(self, x): x = as_xtensor(x) @@ -43,17 +43,17 @@ def make_node(self, x): if d not in reduce_dims_set ] ) - output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + output = xtensor(dtype=x.type.dtype, dims=out_dims) return Apply(self, [x], [output]) -def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: - if isinstance(dim, str): - return (dim,) +def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[DimType]: + if isinstance(dim, DimVariable): + return (dim.type,) elif dim is None or dim is Ellipsis: x = as_xtensor(x) - return typing.cast(tuple[str], x.type.dims) - return dim + return typing.cast(tuple[DimType], x.type.dims) + return tuple(dim.type for dim in dim) def reduce(x, dim: REDUCE_DIM = None, *, binary_op): @@ -80,8 +80,14 @@ def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op): def _infer_reduced_size(original_var, reduced_var): reduced_dims = reduced_var.dims - return variadic_mul( - *[size for dim, size in original_var.sizes if dim not in reduced_dims] + return as_xtensor( + variadic_mul( + *[ + size + for dim, size in original_var.sizes.items() + if dim not in reduced_dims + ] + ) ) @@ -96,7 +102,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0): x = as_xtensor(x) x_mean = mean(x, dim) n = _infer_reduced_size(x, x_mean) - return square(x - x_mean) / (n - ddof) + return square(x - x_mean).mean(dim) / (n - ddof) def std(x, dim: REDUCE_DIM, *, ddof: int = 0): diff --git a/pytensor/xtensor/rewriting/basic.py 
b/pytensor/xtensor/rewriting/basic.py index be93101426..9dad8441c3 100644 --- a/pytensor/xtensor/rewriting/basic.py +++ b/pytensor/xtensor/rewriting/basic.py @@ -1,12 +1,16 @@ from pytensor.graph import node_rewriter -from pytensor.tensor.basic import register_infer_shape -from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless +from pytensor.tensor.rewriting.basic import ( + register_canonicalize, + register_infer_shape, + register_useless, +) from pytensor.xtensor.basic import ( - Rename, + MapDims, TensorFromXTensor, XTensorFromTensor, xtensor_from_tensor, ) +from pytensor.xtensor.dims import DimFromTensor, FromLength, Length from pytensor.xtensor.rewriting.utils import register_lower_xtensor @@ -29,23 +33,51 @@ def useless_tensor_from_xtensor(fgraph, node): @node_rewriter(tracks=[XTensorFromTensor]) def useless_xtensor_from_tensor(fgraph, node): """XTensorFromTensor(TensorFromXTensor(x)) -> x""" - [x] = node.inputs + # TODO + [x, *dims] = node.inputs if x.owner and isinstance(x.owner.op, TensorFromXTensor): return [x.owner.inputs[0]] +@register_infer_shape +@register_useless +@register_canonicalize +@register_lower_xtensor +@node_rewriter(tracks=[Length]) +def useless_length(fgraph, node): + """Length(FromLength(x)) -> x""" + [dim] = node.inputs + if dim.owner and isinstance(dim.owner.op, FromLength): + return [dim.owner.inputs[0]] + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_lower_xtensor +@node_rewriter(tracks=[DimFromTensor]) +def useless_dim_from_tensor(fgraph, node): + """DimFromTensor(XTensorFromTensor(..., dim)) -> dim""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, XTensorFromTensor): + dim_idx = x.type.dims.index(node.op.dim_type) + assert dim_idx != -1, "Dimension not found in XTensorFromTensor input" + [x_orig, *dims] = x.owner.inputs + return [dims[dim_idx]] + + @register_lower_xtensor @node_rewriter(tracks=[TensorFromXTensor]) def useless_tensor_from_xtensor_of_rename(fgraph, node): """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)""" [renamed_x] = node.inputs - if renamed_x.owner and isinstance(renamed_x.owner.op, Rename): + if renamed_x.owner and isinstance(renamed_x.owner.op, MapDims): [x] = renamed_x.owner.inputs return node.op(x, return_list=True) @register_lower_xtensor -@node_rewriter(tracks=[Rename]) +@node_rewriter(tracks=[MapDims]) def useless_rename(fgraph, node): """ @@ -54,7 +86,7 @@ def useless_rename(fgraph, node): """ [renamed_x] = node.inputs if renamed_x.owner: - if isinstance(renamed_x.owner.op, Rename): + if isinstance(renamed_x.owner.op, MapDims): [x] = renamed_x.owner.inputs return [node.op(x)] elif isinstance(renamed_x.owner.op, TensorFromXTensor): diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py index bed7da564b..6f7df737e3 100644 --- a/pytensor/xtensor/rewriting/vectorization.py +++ b/pytensor/xtensor/rewriting/vectorization.py @@ -3,6 +3,7 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.utils import compute_batch_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.dims import rebase_dim from pytensor.xtensor.rewriting.utils import register_lower_xtensor from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise @@ -10,15 +11,18 @@ @register_lower_xtensor @node_rewriter(tracks=[XElemwise]) def lower_elemwise(fgraph, node): - out_dims = node.outputs[0].type.dims + assert len(node.outputs) == 1 + out_dims = 
node.outputs[0].dims
+    out_dims = [rebase_dim(dim, *node.inputs) for dim in out_dims]
+    out_dim_types = [dim.type for dim in out_dims]
 
     # Convert input XTensors to Tensors and align batch dimensions
     tensor_inputs = []
     for inp in node.inputs:
-        inp_dims = inp.type.dims
+        inp_dim_types = inp.type.dims
         order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
+            inp_dim_types.index(out_dim_type) if out_dim_type in inp_dim_types else "x"
+            for out_dim_type in out_dim_types
         ]
         tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
         tensor_inputs.append(tensor_inp)
@@ -29,7 +33,8 @@ def lower_elemwise(fgraph, node):
 
     # Convert output Tensors to XTensors
     new_outs = [
-        xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs
+        xtensor_from_tensor(tensor_out, dims=out_dims, check=False)
+        for tensor_out in tensor_outs
     ]
     return new_outs
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index c5f345e45a..7e1bc0a77f 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
 import typing
 import warnings
+from itertools import combinations
 from types import EllipsisType
+from uuid import UUID, uuid4
 
 from pytensor.compile import (
     DeepCopyOp,
@@ -8,11 +12,11 @@
     register_deep_copy_op_c_code,
     register_view_op_c_code,
 )
+from pytensor.scalar.basic import ScalarType, ScalarVariable
 from pytensor.tensor import (
     TensorType,
     _as_tensor_variable,
     as_tensor_variable,
-    specify_shape,
 )
 from pytensor.tensor.math import variadic_mul
@@ -25,7 +29,7 @@
     XARRAY_AVAILABLE = False
 
 from collections.abc import Sequence
-from typing import Any, Literal, TypeVar
+from typing import Any, Literal, TypeVar, cast
 
 import numpy as np
@@ -38,17 +42,299 @@
 from pytensor.tensor.variable import TensorConstantSignature, TensorVariable
 
+
+# I think uint64 would make more sense, but some code in tensor/rewrites/shape
+# asserts that it is int64?
+# DIM_LENGTH_TYPE = int64
+DIM_LENGTH_TYPE = TensorType(dtype="int64", shape=())
+DIM_LENGTH_VARIABLE = TensorVariable
+
+
+class DimType(Type):
+    """A type for dimensions.
+
+    If two dimensions share the same type, they must have the same
+    length.
+    """
+
+    __props__ = ("name", "size")
+
+    name: str | None
+    size: int | None
+
+    def __init__(self, *, name: str | None = None, size: int | None = None):
+        super().__init__()
+        self.name = name
+        self.size = size
+
+    def base_dims(self) -> set[BasicDim]:
+        raise NotImplementedError(
+            "Subclasses must implement base_dims to return a set of base dimensions."
+        )
+
+    def filter(self, data, strict=False, allow_downcast=None):
+        # At runtime, a dim behaves like a DIM_LENGTH_SCALAR scalar
+        return DIM_LENGTH_TYPE.filter(
+            data, strict=strict, allow_downcast=allow_downcast
+        )
+
+    def filter_variable(self, other, allow_convert=True):
+        """Filter a variable to ensure it is a DimVariable."""
+        if not isinstance(other, Variable):
+            raise ValueError()
+
+        if isinstance(other.type, DimType):
+            return other
+
+        if allow_convert:
+            other2 = self.convert_variable(other)
+            if other2 is not None:
+                return other2
+
+        raise TypeError(
+            f"Cannot convert Type {other.type} (of Variable {other}) into Type {self}."
+        )
+
+    def __repr__(self) -> str:
+        props = []
+        for prop in self.__props__:
+            if not hasattr(self, prop):
+                raise AttributeError(
+                    f"{self.__class__.__name__} has no property '{prop}' even though it is listed in __props__"
+                )
+            value = getattr(self, prop)
+            if value is None:
+                continue
+            if prop == "name":
+                props.insert(0, f"{value}")
+            elif prop == "uuid":
+                props.append("uuid=?")
+            else:
+                props.append(f"{prop}={value!r}")
+        return f"{self.__class__.__name__}({', '.join(props)})"
+
+    def dim_compatible(self, other: DimType):
+        """Test whether this dimension is compatible with another dimension.
+
+        If two dimensions are compatible, they must have a common
+        dimension that they can broadcast to. A tensor may not contain
+        two distinct dimensions that are compatible with each other.
+
+        Dim compatibility *must* be reflexive, symmetric and transitive.
+
+        It defaults to dim equality, but can be overridden by subclasses.
+        """
+        return self == other
+
+    def broadcasted_dim_type(self, *other: DimType) -> DimType | None:
+        """Find the smallest dimension that all dimensions can broadcast to.
+
+        Note that this does not correspond to the usual numpy broadcasting,
+        but will be used mostly to broadcast dimensions that are subsets
+        of some larger dimension.
+
+        If the dimensions are not compatible, it returns None.
+        """
+        if all(self.dim_compatible(o) for o in other):
+            return self
+        return None
+
+    def broadcast_dim(self, dim_var: DimVariable, target_type: DimType) -> DimVariable:
+        """Broadcast this dimension to the given target type.
+
+        If the dimensions are not compatible, it raises a ValueError.
+        """
+        if target_type == self:
+            return dim_var
+
+        if not self.dim_compatible(target_type):
+            raise ValueError(
+                f"Cannot broadcast {self} to {target_type}. "
+                "Dimensions must be compatible."
+            )
+        raise NotImplementedError("Subclass did not implement dim broadcasting")
+
+
+class BasicDim(DimType):
+    """A non-derived dimension type."""
+
+    __props__ = (*DimType.__props__, "uuid")
+
+    uuid: UUID | None = None
+
+    def __init__(self, *, uuid: UUID | None = None, **kwargs):
+        super().__init__(**kwargs)
+        self.uuid = uuid
+
+    def base_dims(self) -> set[BasicDim]:
+        return {self}
+
+
+class SubsetDim(DimType):
+    __props__ = (*DimType.__props__, "base", "subset")
+
+
+class ProductDim(DimType):
+    __props__ = (
+        *DimType.__props__,
+        "dims",
+    )
+
+    dims: tuple[DimType, ...]
+
+    def __init__(self, *, dims: Sequence[DimType], **kwargs):
+        super().__init__(**kwargs)
+        self.dims = tuple(dims)
+
+    def base_dims(self) -> set[BasicDim]:
+        base = set()
+        for dim in self.dims:
+            base = set.union(base, dim.base_dims())
+        return base
+
+
+class ConcatDim(DimType):
+    __props__ = (
+        *DimType.__props__,
+        "dims",
+    )
+
+    dims: tuple[DimType, ...]
+ + def __init__(self, *, dims: Sequence[DimType], **kwargs): + super().__init__(**kwargs) + self.dims = tuple(dims) + + def base_dims(self) -> set[BasicDim]: + base = set() + for dim in self.dims: + base = set.union(base, dim.base_dims()) + return base + + +class ConstSliceDim(DimType): + __props__ = (*DimType.__props__, "base", "slice") + + base: DimType + slice: slice # [int | None, int | None, int | None] + + def __init__(self, *, base: DimType, slice: slice, **kwargs): + super().__init__(**kwargs) + self.base = base + self.slice = slice + + def base_dims(self) -> set[BasicDim]: + return self.base.base_dims() + + +class UnknownIndexedDim(DimType): + __props__ = (*DimType.__props__, "base", "uuid") + + base: DimType + uuid: UUID + + def __init__(self, *, base: DimType, uuid: UUID, **kwargs): + super().__init__(**kwargs) + self.base = base + self.uuid = uuid + + def base_dims(self) -> set[BasicDim]: + return self.base.base_dims() + + +class CloneDim(DimType): + __props__ = (*DimType.__props__, "base", "uuid") + + base: DimType + uuid: UUID + + def __init__(self, *, base: DimType, uuid: UUID, **kwargs): + super().__init__(**kwargs) + self.base = base + self.uuid = uuid + + def base_dims(self) -> set[BasicDim]: + return self.base.base_dims() + + +class DimVariable(Variable[DimType, OptionalApplyType]): + def clone_dim(self, name: str | None = None) -> DimVariable: + """Rename the dimension variable.""" + from pytensor.xtensor.dims import _clone_dim + + return _clone_dim(self, name=name) + + @property + def size(self) -> ScalarVariable: + """Return the length of the dimension variable.""" + import pytensor.xtensor.dims as px_dims + + return px_dims._dim_size(self) + + +class ConstantDim(Constant[DimType], DimVariable): + def __repr__(self, firstPass=True) -> str: + if self.name is None: + return f"UnnamedDim({int(self.data)})" + else: + return f"{self.name}({int(self.data)})" + + +DimType.variable_type = DimVariable +DimType.constant_type = ConstantDim + +_unknown_dim_counter: int = 0 + + +def _new_dim_name() -> str: + global _unknown_dim_counter + count = _unknown_dim_counter + _unknown_dim_counter += 1 + return f"dim{count}" + + +def dim( + name: str | None = None, + size: DimVariable | ScalarVariable | TensorVariable | int | None = None, +) -> DimVariable: + """Create a dimension variable.""" + + if name is None: + name = _new_dim_name() + if size is None: + dim_type = BasicDim(name=name, uuid=uuid4()) + return cast(DimVariable, dim_type.make_variable(name=name)) + if isinstance(size, int): + dim_type = BasicDim(size=size, name=name, uuid=uuid4()) + return cast(DimVariable, dim_type.make_constant(value=size, name=name)) + if isinstance(size, ScalarVariable): + size = as_tensor_variable(size) + if isinstance(size, DIM_LENGTH_VARIABLE): + if size.type != DIM_LENGTH_TYPE: + raise TypeError( + f"length must be a DIM_LENGTH_SCALAR scalar, got {size.type} for {name}" + ) + from pytensor.xtensor.dims import from_length + + return from_length(size, name=name) + if isinstance(size, DimVariable): + return size.clone_dim(name=name) + raise TypeError( + f"length must be an int or a DIM_LENGTH_SCALAR scalar, got {type(size)} for {name}" + ) + + class XTensorType(Type, HasDataType, HasShape): """A `Type` for Xtensors (Xarray-like tensors with dims).""" __props__ = ("dtype", "shape", "dims") + dims: tuple[DimType, ...] 
+ def __init__( self, dtype: str | np.dtype, *, - dims: Sequence[str], - shape: Sequence[int | None] | None = None, + dims: Sequence[DimType], name: str | None = None, ): if dtype == "floatX": @@ -59,14 +345,13 @@ def __init__( self.dims = tuple(dims) if len(set(dims)) < len(dims): raise ValueError(f"Dimensions must be unique. Found duplicates in {dims}: ") - if shape is None: - self.shape = (None,) * len(self.dims) - else: - self.shape = tuple(shape) - if len(self.shape) != len(self.dims): + + for dim1, dim2 in combinations(dims, r=2): + if dim1.dim_compatible(dim2): raise ValueError( - f"Shape {self.shape} must have the same length as dims {self.dims}" + f"Dimensions {dim1} and {dim2} are compatible, but must be distinct. Clone one of them." ) + self.shape = tuple(dim.size for dim in self.dims) self.ndim = len(self.dims) self.name = name self.numpy_dtype = np.dtype(self.dtype) @@ -87,12 +372,12 @@ def clone( dims = self.dims if shape is None: shape = self.shape - return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs) + return type(self)(dtype=dtype, dims=dims, **kwargs) - def filter(self, value, strict=False, allow_downcast=None): + def filter(self, data, strict=False, allow_downcast=None): # XTensorType behaves like TensorType at runtime, so we filter the same way. return TensorType.filter( - self, value, strict=strict, allow_downcast=allow_downcast + self, data, strict=strict, allow_downcast=allow_downcast ) @staticmethod @@ -118,11 +403,12 @@ def filter_variable(self, other, allow_convert=True): f"You can try to manually convert {other} into a {self}. " ) - def convert_variable(self, var): + def convert_variable(self, var: Variable): var_type = var.type if self.is_super(var_type): return var if isinstance(var_type, XTensorType): + var = cast(XTensorVariable, var) if ( self.ndim != var_type.ndim or self.dtype != var_type.dtype @@ -136,32 +422,12 @@ def convert_variable(self, var): if self.is_super(var_type): return var - if any( - s_length is not None - and var_length is not None - and s_length != var_length - for s_length, var_length in zip(self.shape, var_type.shape) - ): - # Incompatible static shapes - return None - - # Needs a specify_shape - return as_xtensor(specify_shape(var.values, self.shape), dims=self.dims) + return var if isinstance(var_type, TensorType): - if ( - self.ndim != var_type.ndim - or self.dtype != var_type.dtype - or any( - s_length is not None - and var_length is not None - and s_length != var_length - for s_length, var_length in zip(self.shape, var_type.shape) - ) - ): - return None - else: - return as_xtensor(specify_shape(var, self.shape), dims=self.dims) + var = cast(TensorVariable, var) + if self.ndim == 0 and var.ndim == 0: + return as_xtensor(var, dims=()) return None @@ -179,29 +445,20 @@ def __eq__(self, other): and self.shape == other.shape ) - def is_super(self, otype): + def is_super(self, otype: Type): if type(self) is not type(otype): return False - if self.dtype != otype.dtype: - return False - if self.dims != otype.dims: - return False - if any( - s_dim_length is not None and s_dim_length != o_dim_length - for s_dim_length, o_dim_length in zip(self.shape, otype.shape) - ): - return False - return True + otype = cast(XTensorType, otype) + return self == otype def xtensor( name: str | None = None, *, - dims: Sequence[str], - shape: Sequence[int | None] | None = None, + dims: Sequence[DimVariable], dtype: str | np.dtype = "floatX", ): - return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) + return XTensorType(dtype=dtype, 
dims=tuple(dim.type for dim in dims))(name=name) _XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType) @@ -356,12 +613,16 @@ def coords(self): raise NotImplementedError("coords not implemented for XTensorVariable") @property - def dims(self) -> tuple[str, ...]: - return self.type.dims + def dims(self) -> tuple[DimVariable, ...]: + from pytensor.xtensor.dims import _dim_from_tensor + + return tuple( + _dim_from_tensor(self, idx) for idx, _ in enumerate(self.type.dims) + ) @property - def sizes(self) -> dict[str, TensorVariable]: - return dict(zip(self.dims, self.shape)) + def sizes(self) -> dict[DimType, TensorVariable]: + return dict(zip(self.type.dims, self.shape)) @property def as_numpy(self): @@ -376,7 +637,7 @@ def ndim(self) -> int: @property def shape(self) -> tuple[TensorVariable, ...]: - return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore + return tuple(as_tensor_variable(dim.size) for dim in self.dims) # type: ignore @property def size(self) -> TensorVariable: @@ -767,6 +1028,7 @@ def signature(self): def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): + # TODO check this function for changes with dim objects x_dims: tuple[str, ...] if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): xarray_dims = x.dims @@ -795,7 +1057,7 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): ) try: return XTensorConstant( - XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape), + XTensorType(dtype=x_data.dtype, dims=x_dims), x_data, name=name, ) @@ -810,7 +1072,9 @@ def as_symbolic_xarray(x, **kwargs): return xtensor_constant(x, **kwargs) -def as_xtensor(x, name=None, dims: Sequence[str] | None = None): +def as_xtensor( + x, name=None, dims: Sequence[DimVariable] | None = None +) -> XTensorVariable: if isinstance(x, Apply): if len(x.outputs) != 1: raise ValueError( @@ -837,10 +1101,13 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None): "non-scalar TensorVariable cannot be converted to XTensorVariable without dims." ) return px.basic.xtensor_from_tensor(x, dims=dims, name=name) - else: - raise TypeError( - "Variable with type {x.type} cannot be converted to XTensorVariable." - ) + + if isinstance(x.type, ScalarType): + # Convert scalar to XTensorVariable with no dims + return as_xtensor(as_tensor_variable(x), name=name, dims=dims) + raise TypeError( + f"Variable with type {x.type} cannot be converted to XTensorVariable." 
+ ) try: return xtensor_constant(x, dims=dims, name=name) except TypeError as err: diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py index 8243e78170..1165dada9d 100644 --- a/pytensor/xtensor/vectorization.py +++ b/pytensor/xtensor/vectorization.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from itertools import chain import numpy as np @@ -13,24 +14,32 @@ get_static_shape_from_size_variables, ) from pytensor.xtensor.basic import XOp -from pytensor.xtensor.type import as_xtensor, xtensor +from pytensor.xtensor.type import ( + DimType, + DimVariable, + XTensorVariable, + as_xtensor, + xtensor, +) -def combine_dims_and_shape(inputs): - dims_and_shape: dict[str, int | None] = {} +def broadcast_xtensors(inputs: Sequence[XTensorVariable]) -> list[DimVariable]: + dims_and_shape: dict[DimType, int | None] = {} + dim_to_dimvar: dict[DimType, DimVariable] = {} for inp in inputs: - for dim, dim_length in zip(inp.type.dims, inp.type.shape): - if dim not in dims_and_shape: - dims_and_shape[dim] = dim_length - elif dim_length is not None: + for dim, dim_length in zip(inp.dims, inp.type.shape): + if dim.type not in dims_and_shape: + dims_and_shape[dim.type] = dim_length + if dim.type not in dim_to_dimvar: + dim_to_dimvar[dim.type] = dim + + if dim_length is not None: # Check for conflicting shapes - if (dims_and_shape[dim] is not None) and ( - dims_and_shape[dim] != dim_length + if (dims_and_shape[dim.type] is not None) and ( + dims_and_shape[dim.type] != dim_length ): raise ValueError(f"Dimension {dim} has conflicting shapes") - # Keep the non-None shape - dims_and_shape[dim] = dim_length - return dims_and_shape + return list(dim_to_dimvar.values()) class XElemwise(XOp): @@ -47,18 +56,14 @@ def make_node(self, *inputs): f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" ) - dims_and_shape = combine_dims_and_shape(inputs) - if dims_and_shape: - output_dims, output_shape = zip(*dims_and_shape.items()) - else: - output_dims, output_shape = (), () + output_dims = broadcast_xtensors(inputs) dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] output_dtypes = [ out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs ] outputs = [ - xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape) + xtensor(dtype=output_dtype, dims=output_dims) for output_dtype in output_dtypes ] return Apply(self, inputs, outputs) @@ -85,7 +90,7 @@ def make_node(self, *inputs): f"Wrong number of inputs, expected {len(self.core_dims[0])}, got {len(inputs)}" ) - dims_and_shape = combine_dims_and_shape(inputs) + dims_and_shape = broadcast_xtensors(inputs) core_inputs_dims, core_outputs_dims = self.core_dims core_input_dims_set = set(chain.from_iterable(core_inputs_dims)) @@ -216,7 +221,7 @@ def make_node(self, rng, *extra_dim_lengths_and_params): self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths) ) ) - params_dims_and_shape = combine_dims_and_shape(params) + params_dims_and_shape = broadcast_xtensors(params) # Check that no parameter dims conflict with size dims if conflict_dims := set(extra_dims_and_shape).intersection( diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 376532f8ab..c287a8294d 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -13,7 +13,7 @@ import pytensor.xtensor.math as pxm from pytensor import function from pytensor.scalar import ScalarOp -from pytensor.xtensor.basic import rename +from pytensor.xtensor.basic import map_dims from 
pytensor.xtensor.math import add, exp from pytensor.xtensor.type import xtensor from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function @@ -77,8 +77,8 @@ def test_dimension_alignment(): def test_renamed_dimension_alignment(): x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3)) - y = rename(x, b1="b2", b2="b1") - z = rename(x, b2="b3") + y = map_dims(x, b1="b2", b2="b1") + z = map_dims(x, b2="b3") assert y.type.dims == ("a", "b2", "b1") assert z.type.dims == ("a", "b1", "b3")