-
Notifications
You must be signed in to change notification settings - Fork 137
Implement xarray-like labeled tensors and semantics #1411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+5,471
−52
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
588494c
Avoid no-op DimShuffle
ricardoV94 068229e
Use DimShuffle instead of Reshape in `ix_`
ricardoV94 4d15f7f
Extract ViewOp functionality into a base TypeCastOp
ricardoV94 fb11cc5
Implement basic labeled tensor functionality
ricardoV94 d753f3e
Implement stack for XTensorVariables
ricardoV94 bec40a2
Implement Elemwise and Blockwise operations for XTensorVariables
ricardoV94 8012134
Implement cast for XTensorVariables
ricardoV94 3e7a803
Implement reduction operations for XTensorVariables
ricardoV94 5442a2d
Implement concat for XTensorVariables
ricardoV94 1cb5289
Implement transpose for XTensorVariables
AllenDowney 5c62008
Implement unstack for XTensorVariables
OriolAbril a14ec0b
Implement index for XTensorVariables
ricardoV94 b08b395
Implement index update for XTensorVariables
ricardoV94 a60d0ec
Implement diff for XTensorVariables
ricardoV94 86d58b9
Implement squeeze for XTensorVariables
AllenDowney 6e2bd64
Implement expand_dims for XTensorVariables (#1449)
AllenDowney 81cbec8
Implement dot for XTensorVariables (#1475)
AllenDowney 41d9be4
Implement XTensorVariable version of RandomVariables
ricardoV94 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import warnings | ||
|
||
import pytensor.xtensor.rewriting | ||
from pytensor.xtensor import linalg | ||
from pytensor.xtensor.math import dot | ||
from pytensor.xtensor.shape import concat | ||
from pytensor.xtensor.type import ( | ||
as_xtensor, | ||
xtensor, | ||
xtensor_constant, | ||
) | ||
|
||
|
||
warnings.warn("xtensor module is experimental and full of bugs") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from collections.abc import Sequence | ||
|
||
from pytensor.compile.ops import TypeCastingOp | ||
from pytensor.graph import Apply, Op | ||
from pytensor.tensor.type import TensorType | ||
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor | ||
|
||
|
||
class XOp(Op): | ||
"""A base class for XOps that shouldn't be materialized""" | ||
|
||
def perform(self, node, inputs, outputs): | ||
raise NotImplementedError( | ||
f"xtensor operation {self} must be lowered to equivalent tensor operations" | ||
) | ||
|
||
|
||
class XTypeCastOp(TypeCastingOp): | ||
"""Base class for Ops that type cast between TensorType and XTensorType. | ||
This is like a `ViewOp` but without the expectation the input and output have identical types. | ||
""" | ||
|
||
|
||
class TensorFromXTensor(XTypeCastOp): | ||
__props__ = () | ||
|
||
def make_node(self, x): | ||
if not isinstance(x.type, XTensorType): | ||
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") | ||
output = TensorType(x.type.dtype, shape=x.type.shape)() | ||
return Apply(self, [x], [output]) | ||
|
||
def L_op(self, inputs, outs, g_outs): | ||
[x] = inputs | ||
[g_out] = g_outs | ||
return [xtensor_from_tensor(g_out, dims=x.type.dims)] | ||
|
||
|
||
tensor_from_xtensor = TensorFromXTensor() | ||
|
||
|
||
class XTensorFromTensor(XTypeCastOp): | ||
__props__ = ("dims",) | ||
|
||
def __init__(self, dims: Sequence[str]): | ||
super().__init__() | ||
self.dims = tuple(dims) | ||
|
||
def make_node(self, x): | ||
if not isinstance(x.type, TensorType): | ||
raise TypeError(f"x must be an TensorType type, got {type(x.type)}") | ||
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape) | ||
return Apply(self, [x], [output]) | ||
|
||
def L_op(self, inputs, outs, g_outs): | ||
[g_out] = g_outs | ||
return [tensor_from_xtensor(g_out)] | ||
|
||
|
||
def xtensor_from_tensor(x, dims, name=None): | ||
return XTensorFromTensor(dims=dims)(x, name=name) | ||
|
||
|
||
class Rename(XTypeCastOp): | ||
__props__ = ("new_dims",) | ||
|
||
def __init__(self, new_dims: tuple[str, ...]): | ||
super().__init__() | ||
self.new_dims = new_dims | ||
|
||
def make_node(self, x): | ||
x = as_xtensor(x) | ||
output = x.type.clone(dims=self.new_dims)() | ||
return Apply(self, [x], [output]) | ||
|
||
def L_op(self, inputs, outs, g_outs): | ||
[x] = inputs | ||
[g_out] = g_outs | ||
return [rename(g_out, dims=x.type.dims)] | ||
|
||
|
||
def rename(x, name_dict: dict[str, str] | None = None, **names: str): | ||
if name_dict is not None: | ||
if names: | ||
raise ValueError("Cannot use both positional and keyword names in rename") | ||
names = name_dict | ||
|
||
x = as_xtensor(x) | ||
old_names = x.type.dims | ||
new_names = list(old_names) | ||
for old_name, new_name in names.items(): | ||
try: | ||
new_names[old_names.index(old_name)] = new_name | ||
except ValueError: | ||
raise ValueError( | ||
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}" | ||
) | ||
|
||
return Rename(tuple(new_names))(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
# HERE LIE DRAGONS | ||
# Useful links to make sense of all the numpy/xarray complexity | ||
# https://numpy.org/devdocs//user/basics.indexing.html | ||
# https://numpy.org/neps/nep-0021-advanced-indexing.html | ||
# https://docs.xarray.dev/en/latest/user-guide/indexing.html | ||
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html | ||
from typing import Literal | ||
|
||
from pytensor.graph.basic import Apply, Constant, Variable | ||
from pytensor.scalar.basic import discrete_dtypes | ||
from pytensor.tensor.basic import as_tensor | ||
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice | ||
from pytensor.xtensor.basic import XOp, xtensor_from_tensor | ||
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor | ||
|
||
|
||
def as_idx_variable(idx, indexed_dim: str): | ||
if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)): | ||
raise TypeError( | ||
"XTensors do not support indexing with None (np.newaxis), use expand_dims instead" | ||
) | ||
if isinstance(idx, slice): | ||
idx = make_slice(idx) | ||
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): | ||
pass | ||
elif ( | ||
isinstance(idx, tuple) | ||
and len(idx) == 2 | ||
and ( | ||
isinstance(idx[0], str) | ||
or ( | ||
isinstance(idx[0], tuple | list) | ||
and all(isinstance(d, str) for d in idx[0]) | ||
) | ||
) | ||
): | ||
# Special case for ("x", array) that xarray supports | ||
dim, idx = idx | ||
if isinstance(idx, Variable) and isinstance(idx.type, XTensorType): | ||
raise IndexError( | ||
f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. " | ||
"Use .rename() instead." | ||
) | ||
if isinstance(dim, str): | ||
dims = (dim,) | ||
else: | ||
dims = tuple(dim) | ||
idx = as_xtensor(as_tensor(idx), dims=dims) | ||
else: | ||
# Must be integer / boolean indices, we already counted for None and slices | ||
try: | ||
idx = as_xtensor(idx) | ||
except TypeError: | ||
idx = as_tensor(idx) | ||
if idx.type.ndim > 1: | ||
# Same error that xarray raises | ||
raise IndexError( | ||
"Unlabeled multi-dimensional array cannot be used for indexing" | ||
) | ||
# This is implicitly an XTensorVariable with dim matching the indexed one | ||
idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim]) | ||
|
||
if idx.type.dtype == "bool": | ||
if idx.type.ndim != 1: | ||
# xarray allaws `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379 | ||
# Otherwise, it is always restricted to 1d boolean indexing arrays | ||
raise NotImplementedError( | ||
"Only 1d boolean indexing arrays are supported" | ||
) | ||
if idx.type.dims != (indexed_dim,): | ||
raise IndexError( | ||
"Boolean indexer should be unlabeled or on the same dimension to the indexed array. " | ||
f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}." | ||
) | ||
|
||
# Convert to nonzero indices | ||
idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims) | ||
|
||
elif idx.type.dtype not in discrete_dtypes: | ||
raise TypeError("Numerical indices must be integers or boolean") | ||
return idx | ||
|
||
|
||
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None: | ||
if dim_length is None: | ||
return None | ||
if isinstance(slc, Constant): | ||
d = slc.data | ||
start, stop, step = d.start, d.stop, d.step | ||
elif slc.owner is None: | ||
# It's a root variable no way of knowing what we're getting | ||
return None | ||
else: | ||
# It's a MakeSliceOp | ||
start, stop, step = slc.owner.inputs | ||
if isinstance(start, Constant): | ||
start = start.data | ||
else: | ||
return None | ||
if isinstance(stop, Constant): | ||
stop = stop.data | ||
else: | ||
return None | ||
if isinstance(step, Constant): | ||
step = step.data | ||
else: | ||
return None | ||
return len(range(*slice(start, stop, step).indices(dim_length))) | ||
|
||
|
||
class Index(XOp): | ||
__props__ = () | ||
|
||
def make_node(self, x, *idxs): | ||
x = as_xtensor(x) | ||
|
||
if any(idx is Ellipsis for idx in idxs): | ||
if idxs.count(Ellipsis) > 1: | ||
raise IndexError("an index can only have a single ellipsis ('...')") | ||
# Convert intermediate Ellipsis to slice(None) | ||
ellipsis_loc = idxs.index(Ellipsis) | ||
n_implied_none_slices = x.type.ndim - (len(idxs) - 1) | ||
idxs = ( | ||
*idxs[:ellipsis_loc], | ||
*((slice(None),) * n_implied_none_slices), | ||
*idxs[ellipsis_loc + 1 :], | ||
) | ||
|
||
x_ndim = x.type.ndim | ||
x_dims = x.type.dims | ||
x_shape = x.type.shape | ||
out_dims = [] | ||
out_shape = [] | ||
|
||
def combine_dim_info(idx_dim, idx_dim_shape): | ||
if idx_dim not in out_dims: | ||
# First information about the dimension length | ||
out_dims.append(idx_dim) | ||
out_shape.append(idx_dim_shape) | ||
else: | ||
# Dim already introduced in output by a previous index | ||
# Update static shape or raise if incompatible | ||
out_dim_pos = out_dims.index(idx_dim) | ||
out_dim_shape = out_shape[out_dim_pos] | ||
if out_dim_shape is None: | ||
# We don't know the size of the dimension yet | ||
out_shape[out_dim_pos] = idx_dim_shape | ||
elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape: | ||
raise IndexError( | ||
f"Dimension of indexers mismatch for dim {idx_dim}" | ||
) | ||
|
||
if len(idxs) > x_ndim: | ||
raise IndexError("Too many indices") | ||
|
||
idxs = [ | ||
as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False) | ||
] | ||
|
||
for i, idx in enumerate(idxs): | ||
if isinstance(idx.type, SliceType): | ||
idx_dim = x_dims[i] | ||
idx_dim_shape = get_static_slice_length(idx, x_shape[i]) | ||
combine_dim_info(idx_dim, idx_dim_shape) | ||
else: | ||
if idx.type.ndim == 0: | ||
# Scalar index, dimension is dropped | ||
continue | ||
|
||
assert isinstance(idx.type, XTensorType) | ||
|
||
idx_dims = idx.type.dims | ||
for idx_dim in idx_dims: | ||
idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)] | ||
combine_dim_info(idx_dim, idx_dim_shape) | ||
|
||
for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]): | ||
# Add back any unindexed dimensions | ||
if dim_i not in out_dims: | ||
# If the dimension was not indexed, we keep it as is | ||
combine_dim_info(dim_i, shape_i) | ||
|
||
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) | ||
return Apply(self, [x, *idxs], [output]) | ||
|
||
|
||
index = Index() | ||
|
||
|
||
class IndexUpdate(XOp): | ||
__props__ = ("mode",) | ||
|
||
def __init__(self, mode: Literal["set", "inc"]): | ||
if mode not in ("set", "inc"): | ||
raise ValueError("mode must be 'set' or 'inc'") | ||
self.mode = mode | ||
|
||
def make_node(self, x, y, *idxs): | ||
# Call Index on (x, *idxs) to process inputs and infer output type | ||
x_view_node = index.make_node(x, *idxs) | ||
x, *idxs = x_view_node.inputs | ||
[x_view] = x_view_node.outputs | ||
|
||
try: | ||
y = as_xtensor(y) | ||
except TypeError: | ||
y = as_xtensor(as_tensor(y), dims=x_view.type.dims) | ||
|
||
if not set(y.type.dims).issubset(x_view.type.dims): | ||
raise ValueError( | ||
f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}" | ||
) | ||
|
||
out = x.type() | ||
return Apply(self, [x, y, *idxs], [out]) | ||
|
||
|
||
index_assignment = IndexUpdate("set") | ||
index_increment = IndexUpdate("inc") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from collections.abc import Sequence | ||
from typing import Literal | ||
|
||
from pytensor.tensor.slinalg import Cholesky, Solve | ||
from pytensor.xtensor.type import as_xtensor | ||
from pytensor.xtensor.vectorization import XBlockwise | ||
|
||
|
||
def cholesky( | ||
x, | ||
lower: bool = True, | ||
*, | ||
check_finite: bool = False, | ||
overwrite_a: bool = False, | ||
on_error: Literal["raise", "nan"] = "raise", | ||
dims: Sequence[str], | ||
): | ||
if len(dims) != 2: | ||
raise ValueError(f"Cholesky needs two dims, got {len(dims)}") | ||
|
||
core_op = Cholesky( | ||
lower=lower, | ||
check_finite=check_finite, | ||
overwrite_a=overwrite_a, | ||
on_error=on_error, | ||
) | ||
core_dims = ( | ||
((dims[0], dims[1]),), | ||
((dims[0], dims[1]),), | ||
) | ||
x_op = XBlockwise(core_op, core_dims=core_dims) | ||
return x_op(x) | ||
|
||
|
||
def solve( | ||
a, | ||
b, | ||
dims: Sequence[str], | ||
assume_a="gen", | ||
lower: bool = False, | ||
check_finite: bool = False, | ||
): | ||
a, b = as_xtensor(a), as_xtensor(b) | ||
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] | ||
output_core_dims: tuple[tuple[str] | tuple[str, str]] | ||
if len(dims) == 2: | ||
b_ndim = 1 | ||
[m1_dim] = [dim for dim in dims if dim not in b.type.dims] | ||
m2_dim = dims[0] if dims[0] != m1_dim else dims[1] | ||
input_core_dims = ((m1_dim, m2_dim), (m2_dim,)) | ||
# The shared dim disappears in the output | ||
output_core_dims = ((m1_dim,),) | ||
elif len(dims) == 3: | ||
b_ndim = 2 | ||
[n_dim] = [dim for dim in dims if dim not in a.type.dims] | ||
[m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim] | ||
input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim)) | ||
# The shared dim disappears in the output | ||
output_core_dims = ((m1_dim, n_dim),) | ||
else: | ||
raise ValueError("Solve dims must have length 2 or 3") | ||
|
||
core_op = Solve( | ||
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite | ||
) | ||
x_op = XBlockwise( | ||
core_op, | ||
core_dims=(input_core_dims, output_core_dims), | ||
) | ||
return x_op(a, b) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
import sys | ||
from collections.abc import Iterable | ||
from types import EllipsisType | ||
|
||
import numpy as np | ||
|
||
import pytensor.scalar as ps | ||
from pytensor import config | ||
from pytensor.graph.basic import Apply | ||
from pytensor.scalar import ScalarOp | ||
from pytensor.scalar.basic import _cast_mapping, upcast | ||
from pytensor.xtensor.basic import XOp, as_xtensor | ||
from pytensor.xtensor.type import xtensor | ||
from pytensor.xtensor.vectorization import XElemwise | ||
|
||
|
||
this_module = sys.modules[__name__] | ||
|
||
|
||
def _as_xelemwise(core_op: ScalarOp) -> XElemwise: | ||
out = XElemwise(core_op) | ||
out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables" | ||
return out | ||
|
||
|
||
abs = _as_xelemwise(ps.abs) | ||
add = _as_xelemwise(ps.add) | ||
logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_) | ||
angle = _as_xelemwise(ps.angle) | ||
arccos = _as_xelemwise(ps.arccos) | ||
arccosh = _as_xelemwise(ps.arccosh) | ||
arcsin = _as_xelemwise(ps.arcsin) | ||
arcsinh = _as_xelemwise(ps.arcsinh) | ||
arctan = _as_xelemwise(ps.arctan) | ||
arctan2 = _as_xelemwise(ps.arctan2) | ||
arctanh = _as_xelemwise(ps.arctanh) | ||
betainc = _as_xelemwise(ps.betainc) | ||
betaincinv = _as_xelemwise(ps.betaincinv) | ||
ceil = _as_xelemwise(ps.ceil) | ||
clip = _as_xelemwise(ps.clip) | ||
complex = _as_xelemwise(ps.complex) | ||
conjugate = conj = _as_xelemwise(ps.conj) | ||
cos = _as_xelemwise(ps.cos) | ||
cosh = _as_xelemwise(ps.cosh) | ||
deg2rad = _as_xelemwise(ps.deg2rad) | ||
equal = eq = _as_xelemwise(ps.eq) | ||
erf = _as_xelemwise(ps.erf) | ||
erfc = _as_xelemwise(ps.erfc) | ||
erfcinv = _as_xelemwise(ps.erfcinv) | ||
erfcx = _as_xelemwise(ps.erfcx) | ||
erfinv = _as_xelemwise(ps.erfinv) | ||
exp = _as_xelemwise(ps.exp) | ||
exp2 = _as_xelemwise(ps.exp2) | ||
expm1 = _as_xelemwise(ps.expm1) | ||
floor = _as_xelemwise(ps.floor) | ||
floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div) | ||
gamma = _as_xelemwise(ps.gamma) | ||
gammainc = _as_xelemwise(ps.gammainc) | ||
gammaincc = _as_xelemwise(ps.gammaincc) | ||
gammainccinv = _as_xelemwise(ps.gammainccinv) | ||
gammaincinv = _as_xelemwise(ps.gammaincinv) | ||
gammal = _as_xelemwise(ps.gammal) | ||
gammaln = _as_xelemwise(ps.gammaln) | ||
gammau = _as_xelemwise(ps.gammau) | ||
greater_equal = ge = _as_xelemwise(ps.ge) | ||
greater = gt = _as_xelemwise(ps.gt) | ||
hyp2f1 = _as_xelemwise(ps.hyp2f1) | ||
i0 = _as_xelemwise(ps.i0) | ||
i1 = _as_xelemwise(ps.i1) | ||
identity = _as_xelemwise(ps.identity) | ||
imag = _as_xelemwise(ps.imag) | ||
logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert) | ||
isinf = _as_xelemwise(ps.isinf) | ||
isnan = _as_xelemwise(ps.isnan) | ||
iv = _as_xelemwise(ps.iv) | ||
ive = _as_xelemwise(ps.ive) | ||
j0 = _as_xelemwise(ps.j0) | ||
j1 = _as_xelemwise(ps.j1) | ||
jv = _as_xelemwise(ps.jv) | ||
kve = _as_xelemwise(ps.kve) | ||
less_equal = le = _as_xelemwise(ps.le) | ||
log = _as_xelemwise(ps.log) | ||
log10 = _as_xelemwise(ps.log10) | ||
log1mexp = _as_xelemwise(ps.log1mexp) | ||
log1p = _as_xelemwise(ps.log1p) | ||
log2 = _as_xelemwise(ps.log2) | ||
less = lt = _as_xelemwise(ps.lt) | ||
mod = _as_xelemwise(ps.mod) | ||
multiply = mul = _as_xelemwise(ps.mul) | ||
negative = neg = _as_xelemwise(ps.neg) | ||
not_equal = neq = _as_xelemwise(ps.neq) | ||
logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_) | ||
owens_t = _as_xelemwise(ps.owens_t) | ||
polygamma = _as_xelemwise(ps.polygamma) | ||
power = pow = _as_xelemwise(ps.pow) | ||
psi = _as_xelemwise(ps.psi) | ||
rad2deg = _as_xelemwise(ps.rad2deg) | ||
real = _as_xelemwise(ps.real) | ||
reciprocal = _as_xelemwise(ps.reciprocal) | ||
round = _as_xelemwise(ps.round_half_to_even) | ||
maximum = _as_xelemwise(ps.scalar_maximum) | ||
minimum = _as_xelemwise(ps.scalar_minimum) | ||
second = _as_xelemwise(ps.second) | ||
sigmoid = _as_xelemwise(ps.sigmoid) | ||
sign = _as_xelemwise(ps.sign) | ||
sin = _as_xelemwise(ps.sin) | ||
sinh = _as_xelemwise(ps.sinh) | ||
softplus = _as_xelemwise(ps.softplus) | ||
square = sqr = _as_xelemwise(ps.sqr) | ||
sqrt = _as_xelemwise(ps.sqrt) | ||
subtract = sub = _as_xelemwise(ps.sub) | ||
where = switch = _as_xelemwise(ps.switch) | ||
tan = _as_xelemwise(ps.tan) | ||
tanh = _as_xelemwise(ps.tanh) | ||
tri_gamma = _as_xelemwise(ps.tri_gamma) | ||
true_divide = true_div = _as_xelemwise(ps.true_div) | ||
trunc = _as_xelemwise(ps.trunc) | ||
logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor) | ||
|
||
_xelemwise_cast_op: dict[str, XElemwise] = {} | ||
|
||
|
||
def cast(x, dtype): | ||
if dtype == "floatX": | ||
dtype = config.floatX | ||
else: | ||
dtype = np.dtype(dtype).name | ||
|
||
x = as_xtensor(x) | ||
if x.type.dtype == dtype: | ||
return x | ||
if x.type.dtype.startswith("complex") and not dtype.startswith("complex"): | ||
raise TypeError( | ||
"Casting from complex to real is ambiguous: consider" | ||
" real(), imag(), angle() or abs()" | ||
) | ||
|
||
if dtype not in _xelemwise_cast_op: | ||
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype]) | ||
return _xelemwise_cast_op[dtype](x) | ||
|
||
|
||
def softmax(x, dim=None): | ||
exp_x = exp(x) | ||
return exp_x / exp_x.sum(dim=dim) | ||
|
||
|
||
class XDot(XOp): | ||
"""Matrix multiplication between two XTensorVariables. | ||
This operation performs matrix multiplication between two tensors, automatically | ||
aligning and contracting dimensions. The behavior matches xarray's dot operation. | ||
Parameters | ||
---------- | ||
dims : tuple of str | ||
The dimensions to contract over. If None, will contract over all matching dimensions. | ||
""" | ||
|
||
__props__ = ("dims",) | ||
|
||
def __init__(self, dims: Iterable[str]): | ||
self.dims = dims | ||
super().__init__() | ||
|
||
def make_node(self, x, y): | ||
x = as_xtensor(x) | ||
y = as_xtensor(y) | ||
|
||
x_shape_dict = dict(zip(x.type.dims, x.type.shape)) | ||
y_shape_dict = dict(zip(y.type.dims, y.type.shape)) | ||
|
||
# Check for dimension size mismatches (concrete only) | ||
for dim in self.dims: | ||
x_shape = x_shape_dict.get(dim, None) | ||
y_shape = y_shape_dict.get(dim, None) | ||
if ( | ||
isinstance(x_shape, int) | ||
and isinstance(y_shape, int) | ||
and x_shape != y_shape | ||
): | ||
raise ValueError(f"Size of dim '{dim}' does not match") | ||
|
||
# Determine output dimensions | ||
shape_dict = {**x_shape_dict, **y_shape_dict} | ||
out_dims = tuple(d for d in shape_dict if d not in self.dims) | ||
|
||
# Determine output shape | ||
out_shape = tuple(shape_dict[d] for d in out_dims) | ||
|
||
# Determine output dtype | ||
out_dtype = upcast(x.type.dtype, y.type.dtype) | ||
|
||
out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims) | ||
return Apply(self, [x, y], [out]) | ||
|
||
|
||
def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): | ||
"""Matrix multiplication between two XTensorVariables. | ||
This operation performs matrix multiplication between two tensors, automatically | ||
aligning and contracting dimensions. The behavior matches xarray's dot operation. | ||
Parameters | ||
---------- | ||
x : XTensorVariable | ||
First input tensor | ||
y : XTensorVariable | ||
Second input tensor | ||
dim : str, Iterable[Hashable], EllipsisType, or None, optional | ||
The dimensions to contract over. If None, will contract over all matching dimensions. | ||
If Ellipsis (...), will contract over all dimensions. | ||
Returns | ||
------- | ||
XTensorVariable | ||
The result of the matrix multiplication. | ||
Examples | ||
-------- | ||
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) | ||
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) | ||
>>> z = dot(x, y) # Result has dimensions ("a", "c") | ||
>>> z = dot(x, y, dim=...) # Contract over all dimensions | ||
""" | ||
x = as_xtensor(x) | ||
y = as_xtensor(y) | ||
|
||
x_dims = set(x.type.dims) | ||
y_dims = set(y.type.dims) | ||
intersection = x_dims & y_dims | ||
union = x_dims | y_dims | ||
|
||
# Canonicalize dims | ||
if dim is None: | ||
dim_set = intersection | ||
elif dim is ...: | ||
dim_set = union | ||
elif isinstance(dim, str): | ||
dim_set = {dim} | ||
elif isinstance(dim, Iterable): | ||
dim_set = set(dim) | ||
|
||
# Validate provided dims | ||
# Check if any dimension is not found in either input | ||
for d in dim_set: | ||
if d not in union: | ||
raise ValueError(f"Dimension {d} not found in either input") | ||
|
||
result = XDot(dims=tuple(dim_set))(x, y) | ||
|
||
return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
from collections.abc import Sequence | ||
from functools import wraps | ||
from typing import Literal | ||
|
||
import pytensor.tensor.random.basic as ptr | ||
from pytensor.graph.basic import Variable | ||
from pytensor.tensor.random.op import RandomVariable | ||
from pytensor.xtensor import as_xtensor | ||
from pytensor.xtensor.math import sqrt | ||
from pytensor.xtensor.vectorization import XRV | ||
|
||
|
||
def _as_xrv( | ||
core_op: RandomVariable, | ||
core_inps_dims_map: Sequence[Sequence[int]] | None = None, | ||
core_out_dims_map: Sequence[int] | None = None, | ||
): | ||
"""Helper function to define an XRV constructor. | ||
Parameters | ||
---------- | ||
core_op : RandomVariable | ||
The core random variable operation to wrap. | ||
core_inps_dims_map : Sequence[Sequence[int]] | None, optional | ||
A sequence of sequences mapping the core dimensions (specified by the user) | ||
for each input parameter. This is used when lowering to a RandomVariable operation, | ||
to decide the ordering of the core dimensions for each input. | ||
If None, it assumes the core dimensions are positional from left to right. | ||
core_out_dims_map : Sequence[int] | None, optional | ||
A sequence mapping the core dimensions (specified by the user) for the output variable. | ||
This is used when lowering to a RandomVariable operation, | ||
to decide the ordering of the core dimensions for the output. | ||
If None, it assumes the core dimensions are positional from left to right. | ||
""" | ||
if core_inps_dims_map is None: | ||
# Assume core_dims map positionally from left to right | ||
core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params] | ||
if core_out_dims_map is None: | ||
# Assume core_dims map positionally from left to right | ||
core_out_dims_map = tuple(range(core_op.ndim_supp)) | ||
|
||
core_dims_needed = max( | ||
(*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0 | ||
) | ||
|
||
@wraps(core_op) | ||
def xrv_constructor( | ||
*params, | ||
core_dims: Sequence[str] | str | None = None, | ||
extra_dims: dict[str, Variable] | None = None, | ||
rng: Variable | None = None, | ||
): | ||
if core_dims is None: | ||
core_dims = () | ||
if core_dims_needed: | ||
raise ValueError( | ||
f"{core_op.name} needs {core_dims_needed} core_dims to be specified" | ||
) | ||
elif isinstance(core_dims, str): | ||
core_dims = (core_dims,) | ||
|
||
if len(core_dims) != core_dims_needed: | ||
raise ValueError( | ||
f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}" | ||
) | ||
|
||
full_input_core_dims = tuple( | ||
tuple(core_dims[i] for i in inp_dims_map) | ||
for inp_dims_map in core_inps_dims_map | ||
) | ||
full_output_core_dims = tuple(core_dims[i] for i in core_out_dims_map) | ||
full_core_dims = (full_input_core_dims, full_output_core_dims) | ||
|
||
if extra_dims is None: | ||
extra_dims = {} | ||
|
||
return XRV( | ||
core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys()) | ||
)(rng, *extra_dims.values(), *params) | ||
|
||
return xrv_constructor | ||
|
||
|
||
bernoulli = _as_xrv(ptr.bernoulli) | ||
beta = _as_xrv(ptr.beta) | ||
betabinom = _as_xrv(ptr.betabinom) | ||
binomial = _as_xrv(ptr.binomial) | ||
categorical = _as_xrv(ptr.categorical) | ||
cauchy = _as_xrv(ptr.cauchy) | ||
dirichlet = _as_xrv(ptr.dirichlet) | ||
exponential = _as_xrv(ptr.exponential) | ||
gamma = _as_xrv(ptr._gamma) | ||
gengamma = _as_xrv(ptr.gengamma) | ||
geometric = _as_xrv(ptr.geometric) | ||
gumbel = _as_xrv(ptr.gumbel) | ||
halfcauchy = _as_xrv(ptr.halfcauchy) | ||
halfnormal = _as_xrv(ptr.halfnormal) | ||
hypergeometric = _as_xrv(ptr.hypergeometric) | ||
integers = _as_xrv(ptr.integers) | ||
invgamma = _as_xrv(ptr.invgamma) | ||
laplace = _as_xrv(ptr.laplace) | ||
logistic = _as_xrv(ptr.logistic) | ||
lognormal = _as_xrv(ptr.lognormal) | ||
multinomial = _as_xrv(ptr.multinomial) | ||
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial) | ||
normal = _as_xrv(ptr.normal) | ||
pareto = _as_xrv(ptr.pareto) | ||
poisson = _as_xrv(ptr.poisson) | ||
t = _as_xrv(ptr.t) | ||
triangular = _as_xrv(ptr.triangular) | ||
truncexpon = _as_xrv(ptr.truncexpon) | ||
uniform = _as_xrv(ptr.uniform) | ||
vonmises = _as_xrv(ptr.vonmises) | ||
wald = _as_xrv(ptr.wald) | ||
weibull = _as_xrv(ptr.weibull) | ||
|
||
|
||
def multivariate_normal( | ||
mean, | ||
cov, | ||
*, | ||
core_dims: Sequence[str], | ||
extra_dims=None, | ||
rng=None, | ||
method: Literal["cholesky", "svd", "eigh"] = "cholesky", | ||
): | ||
mean = as_xtensor(mean) | ||
if len(core_dims) != 2: | ||
raise ValueError( | ||
f"multivariate_normal requires 2 core_dims, got {len(core_dims)}" | ||
) | ||
|
||
# Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov | ||
# This will be the core dimension of the output | ||
if core_dims[0] not in mean.type.dims: | ||
core_dims = core_dims[::-1] | ||
|
||
xop = _as_xrv(ptr.MvNormalRV(method=method)) | ||
return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng) | ||
|
||
|
||
def standard_normal( | ||
extra_dims: dict[str, Variable] | None = None, | ||
rng: Variable | None = None, | ||
): | ||
"""Standard normal random variable.""" | ||
return normal(0, 1, extra_dims=extra_dims, rng=rng) | ||
|
||
|
||
def chisquare( | ||
df, | ||
extra_dims: dict[str, Variable] | None = None, | ||
rng: Variable | None = None, | ||
): | ||
"""Chi-square random variable.""" | ||
return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng) | ||
|
||
|
||
def rayleigh( | ||
scale, | ||
extra_dims: dict[str, Variable] | None = None, | ||
rng: Variable | None = None, | ||
): | ||
"""Rayleigh random variable.""" | ||
|
||
df = scale * 0 + 2 # Poor man's broadcasting, to pass dimensions of scale to the RV | ||
return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# XTensor Module | ||
|
||
This module implements as abstraction layer on regular tensor operations, that behaves like Xarray. | ||
|
||
A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute, | ||
that labels the dimensions of the tensor. | ||
|
||
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects. | ||
|
||
The module implements several PyTensor operations `XOp`s, whose signature mimics that of xarray (and xarray_einstants) DataArray operations. | ||
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into | ||
a regular tensor graph that can itself be evaluated as usual. | ||
|
||
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. | ||
If the existing XOps can be composed to produce the desired result, then we can use them directly. | ||
|
||
## Coordinates | ||
For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. | ||
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. | ||
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. | ||
|
||
## Example | ||
|
||
```python | ||
import pytensor.tensor as pt | ||
import pytensor.xtensor as px | ||
|
||
a = pt.tensor("a", shape=(3,)) | ||
b = pt.tensor("b", shape=(4,)) | ||
|
||
ax = px.as_xtensor(a, dims=["x"]) | ||
bx = px.as_xtensor(b, dims=["y"]) | ||
|
||
zx = ax + bx | ||
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) | ||
|
||
z = zx.values | ||
z.dprint() | ||
# TensorFromXTensor [id A] | ||
# └─ XElemwise{scalar_op=Add()} [id B] | ||
# ├─ XTensorFromTensor{dims=('x',)} [id C] | ||
# │ └─ a [id D] | ||
# └─ XTensorFromTensor{dims=('y',)} [id E] | ||
# └─ b [id F] | ||
``` | ||
|
||
Once we compile the graph, no `XOp`s are left. | ||
|
||
```python | ||
import pytensor | ||
|
||
with pytensor.config.change_flags(optimizer_verbose=True): | ||
fn = pytensor.function([a, b], z) | ||
|
||
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) | ||
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None | ||
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None | ||
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) | ||
|
||
fn.dprint() | ||
# Add [id A] 2 | ||
# ├─ ExpandDims{axis=1} [id B] 1 | ||
# │ └─ a [id C] | ||
# └─ ExpandDims{axis=0} [id D] 0 | ||
# └─ b [id E] | ||
``` | ||
|
||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import typing | ||
from collections.abc import Sequence | ||
from functools import partial | ||
from types import EllipsisType | ||
|
||
import pytensor.scalar as ps | ||
from pytensor.graph.basic import Apply | ||
from pytensor.tensor.math import variadic_mul | ||
from pytensor.xtensor.basic import XOp | ||
from pytensor.xtensor.math import neq, sqrt | ||
from pytensor.xtensor.math import sqr as square | ||
from pytensor.xtensor.type import as_xtensor, xtensor | ||
|
||
|
||
REDUCE_DIM = str | Sequence[str] | EllipsisType | None | ||
|
||
|
||
class XReduce(XOp): | ||
__slots__ = ("binary_op", "dims") | ||
|
||
def __init__(self, binary_op, dims: Sequence[str]): | ||
super().__init__() | ||
self.binary_op = binary_op | ||
# Order of reduce dims doesn't change the behavior of the Op | ||
self.dims = tuple(sorted(dims)) | ||
|
||
def make_node(self, x): | ||
x = as_xtensor(x) | ||
x_dims = x.type.dims | ||
x_dims_set = set(x_dims) | ||
reduce_dims_set = set(self.dims) | ||
if x_dims_set == reduce_dims_set: | ||
out_dims, out_shape = [], [] | ||
else: | ||
if not reduce_dims_set.issubset(x_dims_set): | ||
raise ValueError( | ||
f"Reduced dims {self.dims} not found in array dimensions {x_dims}." | ||
) | ||
out_dims, out_shape = zip( | ||
*[ | ||
(d, s) | ||
for d, s in zip(x_dims, x.type.shape) | ||
if d not in reduce_dims_set | ||
] | ||
) | ||
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) | ||
return Apply(self, [x], [output]) | ||
|
||
|
||
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: | ||
if isinstance(dim, str): | ||
return (dim,) | ||
elif dim is None or dim is Ellipsis: | ||
x = as_xtensor(x) | ||
return typing.cast(tuple[str], x.type.dims) | ||
return dim | ||
|
||
|
||
def reduce(x, dim: REDUCE_DIM = None, *, binary_op): | ||
dims = _process_user_dims(x, dim) | ||
return XReduce(binary_op=binary_op, dims=dims)(x) | ||
|
||
|
||
sum = partial(reduce, binary_op=ps.add) | ||
prod = partial(reduce, binary_op=ps.mul) | ||
max = partial(reduce, binary_op=ps.scalar_maximum) | ||
min = partial(reduce, binary_op=ps.scalar_minimum) | ||
|
||
|
||
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op): | ||
x = as_xtensor(x) | ||
if x.type.dtype != "bool": | ||
x = neq(x, 0) | ||
return reduce(x, dim=dim, binary_op=binary_op) | ||
|
||
|
||
all = partial(bool_reduce, binary_op=ps.and_) | ||
any = partial(bool_reduce, binary_op=ps.or_) | ||
|
||
|
||
def _infer_reduced_size(original_var, reduced_var): | ||
reduced_dims = reduced_var.dims | ||
return variadic_mul( | ||
*[size for dim, size in original_var.sizes if dim not in reduced_dims] | ||
) | ||
|
||
|
||
def mean(x, dim: REDUCE_DIM): | ||
x = as_xtensor(x) | ||
sum_x = sum(x, dim) | ||
n = _infer_reduced_size(x, sum_x) | ||
return sum_x / n | ||
|
||
|
||
def var(x, dim: REDUCE_DIM, *, ddof: int = 0): | ||
x = as_xtensor(x) | ||
x_mean = mean(x, dim) | ||
n = _infer_reduced_size(x, x_mean) | ||
return square(x - x_mean) / (n - ddof) | ||
|
||
|
||
def std(x, dim: REDUCE_DIM, *, ddof: int = 0): | ||
return sqrt(var(x, dim, ddof=ddof)) | ||
|
||
|
||
class XCumReduce(XOp): | ||
__props__ = ("binary_op", "dims") | ||
|
||
def __init__(self, binary_op, dims: Sequence[str]): | ||
self.binary_op = binary_op | ||
self.dims = tuple(sorted(dims)) # Order doesn't matter | ||
|
||
def make_node(self, x): | ||
x = as_xtensor(x) | ||
out = x.type() | ||
return Apply(self, [x], [out]) | ||
|
||
|
||
def cumreduce(x, dim: REDUCE_DIM, *, binary_op): | ||
dims = _process_user_dims(x, dim) | ||
return XCumReduce(dims=dims, binary_op=binary_op)(x) | ||
|
||
|
||
cumsum = partial(cumreduce, binary_op=ps.add) | ||
cumprod = partial(cumreduce, binary_op=ps.mul) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import pytensor.xtensor.rewriting.basic | ||
import pytensor.xtensor.rewriting.indexing | ||
import pytensor.xtensor.rewriting.math | ||
import pytensor.xtensor.rewriting.reduction | ||
import pytensor.xtensor.rewriting.shape | ||
import pytensor.xtensor.rewriting.vectorization |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from pytensor.graph import node_rewriter | ||
from pytensor.tensor.basic import register_infer_shape | ||
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless | ||
from pytensor.xtensor.basic import ( | ||
Rename, | ||
TensorFromXTensor, | ||
XTensorFromTensor, | ||
xtensor_from_tensor, | ||
) | ||
from pytensor.xtensor.rewriting.utils import register_lower_xtensor | ||
|
||
|
||
@register_infer_shape | ||
@register_useless | ||
@register_canonicalize | ||
@register_lower_xtensor | ||
@node_rewriter(tracks=[TensorFromXTensor]) | ||
def useless_tensor_from_xtensor(fgraph, node): | ||
"""TensorFromXTensor(XTensorFromTensor(x)) -> x""" | ||
[x] = node.inputs | ||
if x.owner and isinstance(x.owner.op, XTensorFromTensor): | ||
return [x.owner.inputs[0]] | ||
|
||
|
||
@register_infer_shape | ||
@register_useless | ||
@register_canonicalize | ||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XTensorFromTensor]) | ||
def useless_xtensor_from_tensor(fgraph, node): | ||
"""XTensorFromTensor(TensorFromXTensor(x)) -> x""" | ||
[x] = node.inputs | ||
if x.owner and isinstance(x.owner.op, TensorFromXTensor): | ||
return [x.owner.inputs[0]] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[TensorFromXTensor]) | ||
def useless_tensor_from_xtensor_of_rename(fgraph, node): | ||
"""TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)""" | ||
[renamed_x] = node.inputs | ||
if renamed_x.owner and isinstance(renamed_x.owner.op, Rename): | ||
[x] = renamed_x.owner.inputs | ||
return node.op(x, return_list=True) | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[Rename]) | ||
def useless_rename(fgraph, node): | ||
""" | ||
Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims) | ||
Rename(X, XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFrom_tensor(x, outer_dims) | ||
""" | ||
[renamed_x] = node.inputs | ||
if renamed_x.owner: | ||
if isinstance(renamed_x.owner.op, Rename): | ||
[x] = renamed_x.owner.inputs | ||
return [node.op(x)] | ||
elif isinstance(renamed_x.owner.op, TensorFromXTensor): | ||
[x] = renamed_x.owner.inputs | ||
return [xtensor_from_tensor(x, dims=node.op.new_dims)] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
from itertools import zip_longest | ||
|
||
from pytensor import as_symbolic | ||
from pytensor.graph import Constant, node_rewriter | ||
from pytensor.tensor import TensorType, arange, specify_shape | ||
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor | ||
from pytensor.tensor.type_other import NoneTypeT, SliceType | ||
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.indexing import Index, IndexUpdate, index | ||
from pytensor.xtensor.rewriting.utils import register_lower_xtensor | ||
from pytensor.xtensor.type import XTensorType | ||
|
||
|
||
def to_basic_idx(idx): | ||
if isinstance(idx.type, SliceType): | ||
if isinstance(idx, Constant): | ||
return idx.data | ||
elif idx.owner: | ||
# MakeSlice Op | ||
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible | ||
return slice( | ||
*[ | ||
None if isinstance(i.type, NoneTypeT) else i | ||
for i in idx.owner.inputs | ||
] | ||
) | ||
else: | ||
return idx | ||
if ( | ||
isinstance(idx.type, XTensorType) | ||
and idx.type.ndim == 0 | ||
and idx.type.dtype != bool | ||
): | ||
return idx.values | ||
raise TypeError("Cannot convert idx to basic idx") | ||
|
||
|
||
def _lower_index(node): | ||
"""Lower XTensorVariable indexing to regular TensorVariable indexing. | ||
xarray-like indexing has two modes: | ||
1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices. | ||
2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing. | ||
An Index Op can combine both modes. | ||
To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing. | ||
We expand the dims of each index so they are as large as the number of output dimensions, place the indices that | ||
belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes. | ||
For instance to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]], | ||
This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as have | ||
indices that have more than one dimension at the start. | ||
In addition, xarray basic index (slices), can be vectorized with other advanced indices (if they act on the same output dimension). | ||
However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices | ||
we have to convert the slices to equivalent advanced indices. | ||
We do this by creating an `arange` tensor that matches the shape of the dimension being indexed, | ||
and then indexing it with the original slice. This index is then handled as a regular advanced index. | ||
Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy. | ||
When all advanced indices are consecutive, the respective view is located in the "original" location. | ||
However, if advanced indices are separated by basic indices (slices in our case), the output views | ||
always show up at the front of the array. This information is returned as the second output of this function, | ||
which labels the final position of the indexed dimensions under this rule. | ||
""" | ||
|
||
assert isinstance(node.op, Index) | ||
|
||
x, *idxs = node.inputs | ||
[out] = node.outputs | ||
x_tensor_indexed_dims = out.type.dims | ||
x_tensor = tensor_from_xtensor(x) | ||
|
||
if all( | ||
( | ||
isinstance(idx.type, SliceType) | ||
or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0) | ||
) | ||
for idx in idxs | ||
): | ||
# Special case having just basic indexing | ||
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)] | ||
|
||
else: | ||
# General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing | ||
# May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index | ||
x_dims = x.type.dims | ||
x_shape = tuple(x.shape) | ||
out_ndim = out.type.ndim | ||
out_dims = out.type.dims | ||
aligned_idxs = [] | ||
basic_idx_axis = [] | ||
# zip_longest adds the implicit slice(None) | ||
for i, (idx, x_dim) in enumerate( | ||
zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None))) | ||
): | ||
if isinstance(idx.type, SliceType): | ||
if not any( | ||
( | ||
isinstance(other_idx.type, XTensorType) | ||
and x_dim in other_idx.dims | ||
) | ||
for j, other_idx in enumerate(idxs) | ||
if j != i | ||
): | ||
# We can use basic indexing directly if no other index acts on this dimension | ||
# This is an optimization that avoids creating an unnecessary arange tensor | ||
# and facilitates the use of the specialized AdvancedSubtensor1 when possible | ||
aligned_idxs.append(idx) | ||
basic_idx_axis.append(out_dims.index(x_dim)) | ||
else: | ||
# Otherwise we need to convert the basic index into an equivalent advanced indexing | ||
# And align it so it interacts correctly with the other advanced indices | ||
adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)] | ||
ds_order = ["x"] * out_ndim | ||
ds_order[out_dims.index(x_dim)] = 0 | ||
aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order)) | ||
else: | ||
assert isinstance(idx.type, XTensorType) | ||
if idx.type.ndim == 0: | ||
# Scalar index, we can use it directly | ||
aligned_idxs.append(idx.values) | ||
else: | ||
# Vector index, we need to align the indexing dimensions with the base_dims | ||
ds_order = ["x"] * out_ndim | ||
for j, idx_dim in enumerate(idx.dims): | ||
ds_order[out_dims.index(idx_dim)] = j | ||
aligned_idxs.append(idx.values.dimshuffle(ds_order)) | ||
|
||
# Squeeze indexing dimensions that were not used because we kept basic indexing slices | ||
if basic_idx_axis: | ||
aligned_idxs = [ | ||
idx.squeeze(axis=basic_idx_axis) | ||
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) | ||
else idx | ||
for idx in aligned_idxs | ||
] | ||
|
||
x_tensor_indexed = x_tensor[tuple(aligned_idxs)] | ||
|
||
if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs): | ||
# Numpy moves advanced indexing dimensions to the front when they are not consecutive | ||
# We need to transpose them back to the expected output order | ||
x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis] | ||
x_tensor_indexed_dims = [ | ||
dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims | ||
] + x_tensor_indexed_basic_dims | ||
|
||
return x_tensor_indexed, x_tensor_indexed_dims | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[Index]) | ||
def lower_index(fgraph, node): | ||
"""Lower XTensorVariable indexing to regular TensorVariable indexing. | ||
The bulk of the work is done by `_lower_index`, except for special logic to control the | ||
location of non-consecutive advanced indices, and to preserve static shape information. | ||
""" | ||
|
||
[out] = node.outputs | ||
out_dims = out.type.dims | ||
|
||
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node) | ||
if x_tensor_indexed_dims != out_dims: | ||
# Numpy moves advanced indexing dimensions to the front when they are not consecutive | ||
# We need to transpose them back to the expected output order | ||
transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] | ||
x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) | ||
|
||
# Add lost shape information | ||
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape) | ||
|
||
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[IndexUpdate]) | ||
def lower_index_update(fgraph, node): | ||
"""Lower XTensorVariable index update to regular TensorVariable indexing update. | ||
This rewrite requires converting the index view to a tensor-based equivalent expression, | ||
just like `lower_index`. It then requires aligning the dimensions of y with the | ||
dimensions of the index view, with special care for non-consecutive dimensions being | ||
pulled to the front axis according to numpy rules. | ||
""" | ||
x, y, *idxs = node.inputs | ||
|
||
# Lower the indexing part first | ||
indexed_node = index.make_node(x, *idxs) | ||
x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node) | ||
y_tensor = tensor_from_xtensor(y) | ||
|
||
# Align dimensions of y with those of the indexed tensor x | ||
y_dims = y.type.dims | ||
y_dims_set = set(y_dims) | ||
y_order = tuple( | ||
y_dims.index(x_dim) if x_dim in y_dims_set else "x" | ||
for x_dim in x_tensor_indexed_dims | ||
) | ||
# Remove useless left expand_dims | ||
while len(y_order) > 0 and y_order[0] == "x": | ||
y_order = y_order[1:] | ||
if y_order != tuple(range(y_tensor.type.ndim)): | ||
y_tensor = y_tensor.dimshuffle(y_order) | ||
|
||
x_tensor_updated = inc_subtensor( | ||
x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set" | ||
) | ||
new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims) | ||
return [new_out] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from string import ascii_lowercase | ||
|
||
from pytensor.graph import node_rewriter | ||
from pytensor.tensor import einsum | ||
from pytensor.tensor.shape import specify_shape | ||
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.math import XDot | ||
from pytensor.xtensor.rewriting.utils import register_lower_xtensor | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XDot]) | ||
def lower_dot(fgraph, node): | ||
"""Rewrite XDot to tensor.dot. | ||
This rewrite converts an XDot operation to a tensor-based dot operation, | ||
handling dimension alignment and contraction. | ||
""" | ||
[x, y] = node.inputs | ||
[out] = node.outputs | ||
|
||
# Convert inputs to tensors | ||
x_tensor = tensor_from_xtensor(x) | ||
y_tensor = tensor_from_xtensor(y) | ||
|
||
# Collect all dimension names across inputs and output | ||
all_dims = list( | ||
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims) | ||
) # preserve order | ||
if len(all_dims) > len(ascii_lowercase): | ||
raise ValueError("Too many dimensions to map to einsum subscripts") | ||
|
||
dim_to_char = dict(zip(all_dims, ascii_lowercase)) | ||
|
||
# Build einsum string | ||
x_subs = "".join(dim_to_char[d] for d in x.type.dims) | ||
y_subs = "".join(dim_to_char[d] for d in y.type.dims) | ||
out_subs = "".join(dim_to_char[d] for d in out.type.dims) | ||
einsum_str = f"{x_subs},{y_subs}->{out_subs}" | ||
|
||
# Perform the einsum operation | ||
out_tensor = einsum(einsum_str, x_tensor, y_tensor) | ||
|
||
# Reshape to match the output shape | ||
out_tensor = specify_shape(out_tensor, out.type.shape) | ||
|
||
return [xtensor_from_tensor(out_tensor, out.type.dims)] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from functools import partial | ||
|
||
import pytensor.scalar as ps | ||
from pytensor.graph.rewriting.basic import node_rewriter | ||
from pytensor.tensor.extra_ops import CumOp | ||
from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum | ||
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.reduction import XCumReduce, XReduce | ||
from pytensor.xtensor.rewriting.utils import register_lower_xtensor | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XReduce]) | ||
def lower_reduce(fgraph, node): | ||
[x] = node.inputs | ||
[out] = node.outputs | ||
x_dims = x.type.dims | ||
reduce_dims = node.op.dims | ||
reduce_axis = [x_dims.index(dim) for dim in reduce_dims] | ||
|
||
if not reduce_axis: | ||
return [x] | ||
|
||
match node.op.binary_op: | ||
case ps.add: | ||
tensor_op_class = Sum | ||
case ps.mul: | ||
tensor_op_class = Prod | ||
case ps.and_: | ||
tensor_op_class = All | ||
case ps.or_: | ||
tensor_op_class = Any | ||
case ps.scalar_maximum: | ||
tensor_op_class = Max | ||
case ps.scalar_minimum: | ||
tensor_op_class = Min | ||
case _: | ||
# Case without known/predefined Ops | ||
tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op) | ||
|
||
x_tensor = tensor_from_xtensor(x) | ||
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor) | ||
new_out = xtensor_from_tensor(out_tensor, out.type.dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XCumReduce]) | ||
def lower_cumreduce(fgraph, node): | ||
[x] = node.inputs | ||
x_dims = x.type.dims | ||
reduce_dims = node.op.dims | ||
reduce_axis = [x_dims.index(dim) for dim in reduce_dims] | ||
|
||
if not reduce_axis: | ||
return [x] | ||
|
||
match node.op.binary_op: | ||
case ps.add: | ||
tensor_op_class = partial(CumOp, mode="add") | ||
case ps.mul: | ||
tensor_op_class = partial(CumOp, mode="mul") | ||
case _: | ||
# We don't know how to convert an arbitrary binary cum/reduce Op | ||
return None | ||
|
||
# Each dim corresponds to an application of Cumsum/Cumprod | ||
out_tensor = tensor_from_xtensor(x) | ||
for axis in reduce_axis: | ||
out_tensor = tensor_op_class(axis=axis)(out_tensor) | ||
out = xtensor_from_tensor(out_tensor, x.type.dims) | ||
return [out] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from pytensor.graph import node_rewriter | ||
from pytensor.tensor import ( | ||
broadcast_to, | ||
expand_dims, | ||
join, | ||
moveaxis, | ||
specify_shape, | ||
squeeze, | ||
) | ||
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.rewriting.basic import register_lower_xtensor | ||
from pytensor.xtensor.shape import ( | ||
Concat, | ||
ExpandDims, | ||
Squeeze, | ||
Stack, | ||
Transpose, | ||
UnStack, | ||
) | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[Stack]) | ||
def lower_stack(fgraph, node): | ||
[x] = node.inputs | ||
batch_ndim = x.type.ndim - len(node.op.stacked_dims) | ||
stacked_axes = [ | ||
i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims | ||
] | ||
end = tuple(range(-len(stacked_axes), 0)) | ||
|
||
x_tensor = tensor_from_xtensor(x) | ||
x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end) | ||
if batch_ndim == (x.type.ndim - 1): | ||
# This happens when we stack a "single" dimension, in this case all we need is the transpose | ||
# Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename | ||
final_tensor = x_tensor_transposed | ||
else: | ||
final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1) | ||
final_tensor = x_tensor_transposed.reshape(final_shape) | ||
|
||
new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[UnStack]) | ||
def lower_unstack(fgraph, node): | ||
x = node.inputs[0] | ||
unstacked_lengths = node.inputs[1:] | ||
axis_to_unstack = x.type.dims.index(node.op.old_dim_name) | ||
|
||
x_tensor = tensor_from_xtensor(x) | ||
x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1]) | ||
final_tensor = x_tensor_transposed.reshape( | ||
(*x_tensor_transposed.shape[:-1], *unstacked_lengths) | ||
) | ||
# Reintroduce any static shape information that was lost during the reshape | ||
final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape) | ||
|
||
new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[Concat]) | ||
def lower_concat(fgraph, node): | ||
out_dims = node.outputs[0].type.dims | ||
concat_dim = node.op.dim | ||
concat_axis = out_dims.index(concat_dim) | ||
|
||
# Convert input XTensors to Tensors and align batch dimensions | ||
tensor_inputs = [] | ||
for inp in node.inputs: | ||
inp_dims = inp.type.dims | ||
order = [ | ||
inp_dims.index(out_dim) if out_dim in inp_dims else "x" | ||
for out_dim in out_dims | ||
] | ||
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) | ||
tensor_inputs.append(tensor_inp) | ||
|
||
# Broadcast non-concatenated dimensions of each input | ||
non_concat_shape = [None] * len(out_dims) | ||
for tensor_inp in tensor_inputs: | ||
# TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime | ||
# I'm running this as "shape_unsafe" to simplify the logic / returned graph | ||
for i, (bcast, sh) in enumerate( | ||
zip(tensor_inp.type.broadcastable, tensor_inp.shape) | ||
): | ||
if bcast or i == concat_axis or non_concat_shape[i] is not None: | ||
continue | ||
non_concat_shape[i] = sh | ||
|
||
assert non_concat_shape.count(None) == 1 | ||
|
||
bcast_tensor_inputs = [] | ||
for tensor_inp in tensor_inputs: | ||
# We modify the concat_axis in place, as we don't need the list anywhere else | ||
non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis] | ||
bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape)) | ||
|
||
joined_tensor = join(concat_axis, *bcast_tensor_inputs) | ||
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[Transpose]) | ||
def lower_transpose(fgraph, node): | ||
[x] = node.inputs | ||
# Use the final dimensions that were already computed in make_node | ||
out_dims = node.outputs[0].type.dims | ||
in_dims = x.type.dims | ||
|
||
# Compute the permutation based on the final dimensions | ||
perm = tuple(in_dims.index(d) for d in out_dims) | ||
x_tensor = tensor_from_xtensor(x) | ||
x_tensor_transposed = x_tensor.transpose(perm) | ||
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter([Squeeze]) | ||
def lower_squeeze(fgraph, node): | ||
"""Rewrite Squeeze to tensor.squeeze.""" | ||
[x] = node.inputs | ||
x_tensor = tensor_from_xtensor(x) | ||
x_dims = x.type.dims | ||
dims_to_remove = node.op.dims | ||
axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove) | ||
x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) | ||
|
||
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) | ||
return [new_out] | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter([ExpandDims]) | ||
def lower_expand_dims(fgraph, node): | ||
"""Rewrite ExpandDims using tensor operations.""" | ||
x, size = node.inputs | ||
out = node.outputs[0] | ||
|
||
# Convert inputs to tensors | ||
x_tensor = tensor_from_xtensor(x) | ||
size_tensor = tensor_from_xtensor(size) | ||
|
||
# Get the new dimension name and position | ||
new_axis = 0 # Always insert at front | ||
|
||
# Use tensor operations | ||
if out.type.shape[0] == 1: | ||
# Simple case: just expand with size 1 | ||
result_tensor = expand_dims(x_tensor, new_axis) | ||
else: | ||
# Otherwise broadcast to the requested size | ||
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape)) | ||
|
||
# Preserve static shape information | ||
result_tensor = specify_shape(result_tensor, out.type.shape) | ||
|
||
# Convert result back to xtensor | ||
result = xtensor_from_tensor(result_tensor, dims=out.type.dims) | ||
return [result] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from pytensor.compile import optdb | ||
from pytensor.graph.rewriting.basic import NodeRewriter | ||
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase | ||
|
||
|
||
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False) | ||
|
||
optdb.register( | ||
"lower_xtensor", | ||
lower_xtensor_db, | ||
"fast_run", | ||
"fast_compile", | ||
"minimum_compile", | ||
position=0.1, | ||
) | ||
|
||
|
||
def register_lower_xtensor( | ||
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs | ||
): | ||
if isinstance(node_rewriter, str): | ||
|
||
def register(inner_rewriter: RewriteDatabase | NodeRewriter): | ||
return register_lower_xtensor( | ||
inner_rewriter, node_rewriter, *tags, **kwargs | ||
) | ||
|
||
return register | ||
|
||
else: | ||
name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore | ||
lower_xtensor_db.register( | ||
name, | ||
node_rewriter, | ||
"fast_run", | ||
"fast_compile", | ||
"minimum_compile", | ||
*tags, | ||
**kwargs, | ||
) | ||
return node_rewriter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from pytensor.graph import node_rewriter | ||
from pytensor.tensor.blockwise import Blockwise | ||
from pytensor.tensor.elemwise import Elemwise | ||
from pytensor.tensor.random.utils import compute_batch_shape | ||
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor | ||
from pytensor.xtensor.rewriting.utils import register_lower_xtensor | ||
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XElemwise]) | ||
def lower_elemwise(fgraph, node): | ||
out_dims = node.outputs[0].type.dims | ||
|
||
# Convert input XTensors to Tensors and align batch dimensions | ||
tensor_inputs = [] | ||
for inp in node.inputs: | ||
inp_dims = inp.type.dims | ||
order = [ | ||
inp_dims.index(out_dim) if out_dim in inp_dims else "x" | ||
for out_dim in out_dims | ||
] | ||
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) | ||
tensor_inputs.append(tensor_inp) | ||
|
||
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( | ||
*tensor_inputs, return_list=True | ||
) | ||
|
||
# Convert output Tensors to XTensors | ||
new_outs = [ | ||
xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs | ||
] | ||
return new_outs | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XBlockwise]) | ||
def lower_blockwise(fgraph, node): | ||
op: XBlockwise = node.op | ||
batch_ndim = node.outputs[0].type.ndim - len(op.core_dims[1][0]) | ||
batch_dims = node.outputs[0].type.dims[:batch_ndim] | ||
|
||
# Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end | ||
tensor_inputs = [] | ||
for inp, core_dims in zip(node.inputs, op.core_dims[0]): | ||
inp_dims = inp.type.dims | ||
# Align the batch dims of the input, and place the core dims on the right | ||
batch_order = [ | ||
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x" | ||
for batch_dim in batch_dims | ||
] | ||
core_order = [inp_dims.index(core_dim) for core_dim in core_dims] | ||
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order) | ||
tensor_inputs.append(tensor_inp) | ||
|
||
signature = op.signature or getattr(op.core_op, "gufunc_signature", None) | ||
if signature is None: | ||
# Build a signature based on the core dimensions | ||
# The Op signature could be more strict, as core_dims will never be repeated, but no functionality depends greatly on it | ||
inputs_core_dims, outputs_core_dims = op.core_dims | ||
inputs_signature = ",".join( | ||
f"({', '.join(inp_core_dims)})" for inp_core_dims in inputs_core_dims | ||
) | ||
outputs_signature = ",".join( | ||
f"({', '.join(out_core_dims)})" for out_core_dims in outputs_core_dims | ||
) | ||
signature = f"{inputs_signature}->{outputs_signature}" | ||
tensor_op = Blockwise(core_op=op.core_op, signature=signature) | ||
tensor_outs = tensor_op(*tensor_inputs, return_list=True) | ||
|
||
# Convert output Tensors to XTensors | ||
new_outs = [ | ||
xtensor_from_tensor(tensor_out, dims=old_out.type.dims) | ||
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True) | ||
] | ||
return new_outs | ||
|
||
|
||
@register_lower_xtensor | ||
@node_rewriter(tracks=[XRV]) | ||
def lower_rv(fgraph, node): | ||
op: XRV = node.op | ||
core_op = op.core_op | ||
|
||
_, old_out = node.outputs | ||
rng, *extra_dim_lengths_and_params = node.inputs | ||
extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)] | ||
params = extra_dim_lengths_and_params[len(op.extra_dims) :] | ||
|
||
batch_ndim = old_out.type.ndim - len(op.core_dims[1]) | ||
param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim] | ||
|
||
# Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end | ||
tensor_params = [] | ||
for inp, core_dims in zip(params, op.core_dims[0]): | ||
inp_dims = inp.type.dims | ||
# Align the batch dims of the input, and place the core dims on the right | ||
batch_order = [ | ||
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x" | ||
for batch_dim in param_batch_dims | ||
] | ||
core_order = [inp_dims.index(core_dim) for core_dim in core_dims] | ||
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order) | ||
tensor_params.append(tensor_inp) | ||
|
||
size = None | ||
if op.extra_dims: | ||
# RV size contains the lengths of all batch dimensions, including those coming from the parameters | ||
if tensor_params: | ||
param_batch_shape = tuple( | ||
compute_batch_shape(tensor_params, ndims_params=core_op.ndims_params) | ||
) | ||
else: | ||
param_batch_shape = () | ||
size = [*extra_dim_lengths, *param_batch_shape] | ||
|
||
# RVs are their own core Op | ||
new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs | ||
|
||
# Convert output Tensors to XTensors | ||
new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims) | ||
return [new_next_rng, new_out] |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
from itertools import chain | ||
|
||
import numpy as np | ||
|
||
from pytensor import scalar as ps | ||
from pytensor import shared | ||
from pytensor.graph import Apply, Op | ||
from pytensor.scalar import discrete_dtypes | ||
from pytensor.tensor import tensor | ||
from pytensor.tensor.random.op import RNGConsumerOp | ||
from pytensor.tensor.random.type import RandomType | ||
from pytensor.tensor.utils import ( | ||
get_static_shape_from_size_variables, | ||
) | ||
from pytensor.xtensor.basic import XOp | ||
from pytensor.xtensor.type import as_xtensor, xtensor | ||
|
||
|
||
def combine_dims_and_shape(inputs): | ||
dims_and_shape: dict[str, int | None] = {} | ||
for inp in inputs: | ||
for dim, dim_length in zip(inp.type.dims, inp.type.shape): | ||
if dim not in dims_and_shape: | ||
dims_and_shape[dim] = dim_length | ||
elif dim_length is not None: | ||
# Check for conflicting shapes | ||
if (dims_and_shape[dim] is not None) and ( | ||
dims_and_shape[dim] != dim_length | ||
): | ||
raise ValueError(f"Dimension {dim} has conflicting shapes") | ||
# Keep the non-None shape | ||
dims_and_shape[dim] = dim_length | ||
return dims_and_shape | ||
|
||
|
||
class XElemwise(XOp): | ||
__props__ = ("scalar_op",) | ||
|
||
def __init__(self, scalar_op): | ||
super().__init__() | ||
self.scalar_op = scalar_op | ||
|
||
def make_node(self, *inputs): | ||
inputs = [as_xtensor(inp) for inp in inputs] | ||
if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin): | ||
raise ValueError( | ||
f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" | ||
) | ||
|
||
dims_and_shape = combine_dims_and_shape(inputs) | ||
if dims_and_shape: | ||
output_dims, output_shape = zip(*dims_and_shape.items()) | ||
else: | ||
output_dims, output_shape = (), () | ||
|
||
dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] | ||
output_dtypes = [ | ||
out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs | ||
] | ||
outputs = [ | ||
xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape) | ||
for output_dtype in output_dtypes | ||
] | ||
return Apply(self, inputs, outputs) | ||
|
||
|
||
class XBlockwise(XOp): | ||
__props__ = ("core_op", "core_dims") | ||
|
||
def __init__( | ||
self, | ||
core_op: Op, | ||
core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]], | ||
signature: str | None = None, | ||
): | ||
super().__init__() | ||
self.core_op = core_op | ||
self.core_dims = core_dims | ||
self.signature = signature # Only used for lowering, not for validation | ||
|
||
def make_node(self, *inputs): | ||
inputs = [as_xtensor(i) for i in inputs] | ||
if len(inputs) != len(self.core_dims[0]): | ||
raise ValueError( | ||
f"Wrong number of inputs, expected {len(self.core_dims[0])}, got {len(inputs)}" | ||
) | ||
|
||
dims_and_shape = combine_dims_and_shape(inputs) | ||
|
||
core_inputs_dims, core_outputs_dims = self.core_dims | ||
core_input_dims_set = set(chain.from_iterable(core_inputs_dims)) | ||
batch_dims, batch_shape = zip( | ||
*((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set) | ||
) | ||
|
||
dummy_core_inputs = [] | ||
for inp, core_inp_dims in zip(inputs, core_inputs_dims): | ||
try: | ||
core_static_shape = [ | ||
inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims | ||
] | ||
except IndexError: | ||
raise ValueError( | ||
f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}" | ||
) | ||
dummy_core_inputs.append( | ||
tensor(dtype=inp.type.dtype, shape=core_static_shape) | ||
) | ||
core_node = self.core_op.make_node(*dummy_core_inputs) | ||
|
||
outputs = [ | ||
xtensor( | ||
dtype=core_out.type.dtype, | ||
shape=batch_shape + core_out.type.shape, | ||
dims=batch_dims + core_out_dims, | ||
) | ||
for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims) | ||
] | ||
return Apply(self, inputs, outputs) | ||
|
||
|
||
class XRV(XOp, RNGConsumerOp): | ||
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics. | ||
Xarray does not offer random generators, so this class implements a new API. | ||
It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and | ||
enforces dim-based broadcasting between inputs and output. | ||
It differs from XBlockwise in a couple of ways: | ||
1. It is restricted to one sample output | ||
2. It takes a random generator as the first input and returns the consumed generator as the first output. | ||
3. It has the concept of extra dimensions, which determine extra batch dimensions of the output, that are not | ||
implied by batch dimensions of the parameters. | ||
""" | ||
|
||
default_output = 1 | ||
__props__ = ("core_op", "core_dims", "extra_dims") | ||
|
||
def __init__( | ||
self, | ||
core_op, | ||
core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]], | ||
extra_dims: tuple[str, ...], | ||
): | ||
super().__init__() | ||
self.core_op = core_op | ||
inps_core_dims, out_core_dims = core_dims | ||
for operand_dims in (*inps_core_dims, out_core_dims): | ||
if len(set(operand_dims)) != len(operand_dims): | ||
raise ValueError(f"Operand has repeated dims {operand_dims}") | ||
self.core_dims = (tuple(i for i in inps_core_dims), tuple(out_core_dims)) | ||
if len(set(extra_dims)) != len(extra_dims): | ||
raise ValueError("size_dims must be unique") | ||
self.extra_dims = tuple(extra_dims) | ||
|
||
def update(self, node): | ||
# RNG input and update are the first input and output respectively | ||
return {node.inputs[0]: node.outputs[0]} | ||
|
||
def make_node(self, rng, *extra_dim_lengths_and_params): | ||
if rng is None: | ||
rng = shared(np.random.default_rng()) | ||
elif not isinstance(rng.type, RandomType): | ||
raise TypeError( | ||
"The type of rng should be an instance of RandomGeneratorType " | ||
) | ||
|
||
extra_dim_lengths = [ | ||
as_xtensor(dim_length).values | ||
for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)] | ||
] | ||
if not all( | ||
(dim_length.type.ndim == 0 and dim_length.type.dtype in discrete_dtypes) | ||
for dim_length in extra_dim_lengths | ||
): | ||
raise TypeError("All dimension lengths should be scalar discrete dtype.") | ||
|
||
params = [ | ||
as_xtensor(param) | ||
for param in extra_dim_lengths_and_params[len(self.extra_dims) :] | ||
] | ||
if len(params) != len(self.core_op.ndims_params): | ||
raise ValueError( | ||
f"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, " | ||
f"got {len(extra_dim_lengths_and_params)}" | ||
) | ||
|
||
param_core_dims, output_core_dims = self.core_dims | ||
input_core_dims_set = set(chain.from_iterable(param_core_dims)) | ||
|
||
# Check parameters don't have core dimensions they shouldn't have | ||
for param, core_param_dims in zip(params, param_core_dims): | ||
if invalid_core_dims := ( | ||
set(param.type.dims) - set(core_param_dims) | ||
).intersection(input_core_dims_set): | ||
raise ValueError( | ||
f"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}" | ||
) | ||
|
||
extra_dims_and_shape = dict( | ||
zip( | ||
self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths) | ||
) | ||
) | ||
params_dims_and_shape = combine_dims_and_shape(params) | ||
|
||
# Check that no parameter dims conflict with size dims | ||
if conflict_dims := set(extra_dims_and_shape).intersection( | ||
params_dims_and_shape | ||
): | ||
raise ValueError( | ||
f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique." | ||
) | ||
|
||
batch_dims_and_shape = [ | ||
(dim, dim_length) | ||
for dim, dim_length in ( | ||
extra_dims_and_shape | params_dims_and_shape | ||
).items() | ||
if dim not in input_core_dims_set | ||
] | ||
if batch_dims_and_shape: | ||
batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape) | ||
else: | ||
batch_output_dims, batch_output_shape = (), () | ||
|
||
dummy_core_inputs = [] | ||
for param, core_param_dims in zip(params, param_core_dims): | ||
try: | ||
core_static_shape = [ | ||
param.type.shape[param.type.dims.index(d)] for d in core_param_dims | ||
] | ||
except ValueError: | ||
raise ValueError( | ||
f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}" | ||
) | ||
dummy_core_inputs.append( | ||
tensor(dtype=param.type.dtype, shape=core_static_shape) | ||
) | ||
core_node = self.core_op.make_node(rng, None, *dummy_core_inputs) | ||
|
||
if not len(core_node.outputs) == 2: | ||
raise NotImplementedError( | ||
"XRandomVariable only supports core ops with two outputs (rng, out)" | ||
) | ||
|
||
_, core_out = core_node.outputs | ||
out = xtensor( | ||
dtype=core_out.type.dtype, | ||
shape=batch_output_shape + core_out.type.shape, | ||
dims=batch_output_dims + output_core_dims, | ||
) | ||
|
||
return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# ruff: noqa: E402 | ||
import pytest | ||
|
||
|
||
pytest.importorskip("xarray") | ||
pytest.importorskip("xarray_einstats") | ||
|
||
import numpy as np | ||
from xarray import DataArray | ||
from xarray_einstats.linalg import ( | ||
cholesky as xr_cholesky, | ||
) | ||
from xarray_einstats.linalg import ( | ||
solve as xr_solve, | ||
) | ||
|
||
from pytensor.xtensor.linalg import cholesky, solve | ||
from pytensor.xtensor.type import xtensor | ||
from tests.xtensor.util import xr_assert_allclose, xr_function | ||
|
||
|
||
def test_cholesky(): | ||
x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4)) | ||
y = cholesky(x, dims=["b", "a"]) | ||
assert y.type.dims == ("batch", "b", "a") | ||
assert y.type.shape == (3, 4, 4) | ||
|
||
fn = xr_function([x], y) | ||
rng = np.random.default_rng(25) | ||
x_ = rng.random(size=(3, 4, 4)) | ||
x_ = x_ @ x_.mT | ||
x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims) | ||
xr_assert_allclose( | ||
fn(x_test), | ||
xr_cholesky(x_test, dims=["b", "a"]), | ||
) | ||
|
||
|
||
def test_solve_vector_b(): | ||
a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) | ||
b = xtensor("b", dims=("city", "planet"), shape=(None, 2)) | ||
x = solve(a, b, dims=["country", "city"]) | ||
assert x.type.dims == ("galaxy", "planet", "country") | ||
# Core Solve doesn't make use of the fact A must be square in the static shape | ||
assert x.type.shape == (1, 2, None) | ||
|
||
fn = xr_function([a, b], x) | ||
|
||
rng = np.random.default_rng(25) | ||
a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) | ||
b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims) | ||
|
||
xr_assert_allclose( | ||
fn(a_test, b_test), | ||
xr_solve(a_test, b_test, dims=["country", "city"]), | ||
) | ||
|
||
|
||
def test_solve_matrix_b(): | ||
a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) | ||
b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2)) | ||
x = solve(a, b, dims=["country", "city", "district"]) | ||
assert x.type.dims == ("galaxy", "planet", "country", "district") | ||
# Core Solve doesn't make use of the fact A must be square in the static shape | ||
assert x.type.shape == (1, 2, None, 5) | ||
|
||
fn = xr_function([a, b], x) | ||
|
||
rng = np.random.default_rng(25) | ||
a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) | ||
b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims) | ||
|
||
xr_assert_allclose( | ||
fn(a_test, b_test), | ||
xr_solve(a_test, b_test, dims=["country", "city", "district"]), | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
# ruff: noqa: E402 | ||
import pytest | ||
|
||
|
||
pytest.importorskip("xarray") | ||
|
||
import inspect | ||
|
||
import numpy as np | ||
from xarray import DataArray | ||
|
||
import pytensor.scalar as ps | ||
import pytensor.xtensor.math as pxm | ||
from pytensor import function | ||
from pytensor.scalar import ScalarOp | ||
from pytensor.xtensor.basic import rename | ||
from pytensor.xtensor.math import add, exp | ||
from pytensor.xtensor.type import xtensor | ||
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function | ||
|
||
|
||
def test_all_scalar_ops_are_wrapped(): | ||
# This ignores wrapper functions | ||
pxm_members = {name for name, _ in inspect.getmembers(pxm)} | ||
for name, op in inspect.getmembers(ps): | ||
if name in { | ||
"complex_from_polar", | ||
"inclosedrange", | ||
"inopenrange", | ||
"round_half_away_from_zero", | ||
"round_half_to_even", | ||
"scalar_abs", | ||
"scalar_maximum", | ||
"scalar_minimum", | ||
} or name.startswith("convert_to_"): | ||
# These are not regular numpy functions or are unusual alias | ||
continue | ||
if isinstance(op, ScalarOp) and name not in pxm_members: | ||
raise NotImplementedError(f"ScalarOp {name} not wrapped in xtensor.math") | ||
|
||
|
||
def test_scalar_case(): | ||
x = xtensor("x", dims=(), shape=()) | ||
y = xtensor("y", dims=(), shape=()) | ||
out = add(x, y) | ||
|
||
fn = function([x, y], out) | ||
|
||
x_test = DataArray(2.0, dims=()) | ||
y_test = DataArray(3.0, dims=()) | ||
np.testing.assert_allclose(fn(x_test.values, y_test.values), 5.0) | ||
|
||
|
||
def test_dimension_alignment(): | ||
x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4)) | ||
y = xtensor( | ||
"y", | ||
dims=("galaxy", "country", "city"), | ||
shape=(5, 3, 2), | ||
) | ||
z = xtensor("z", dims=("universe",), shape=(1,)) | ||
out = add(x, y, z) | ||
assert out.type.dims == ("city", "country", "planet", "galaxy", "universe") | ||
|
||
fn = function([x, y, z], out) | ||
|
||
rng = np.random.default_rng(41) | ||
test_x, test_y, test_z = ( | ||
DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims) | ||
for inp in [x, y, z] | ||
) | ||
np.testing.assert_allclose( | ||
fn(test_x.values, test_y.values, test_z.values), | ||
(test_x + test_y + test_z).values, | ||
) | ||
|
||
|
||
def test_renamed_dimension_alignment(): | ||
x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3)) | ||
y = rename(x, b1="b2", b2="b1") | ||
z = rename(x, b2="b3") | ||
assert y.type.dims == ("a", "b2", "b1") | ||
assert z.type.dims == ("a", "b1", "b3") | ||
|
||
out1 = add(x, x) # self addition | ||
assert out1.type.dims == ("a", "b1", "b2") | ||
out2 = add(x, y) # transposed addition | ||
assert out2.type.dims == ("a", "b1", "b2") | ||
out3 = add(x, z) # outer addition | ||
assert out3.type.dims == ("a", "b1", "b2", "b3") | ||
|
||
fn = xr_function([x], [out1, out2, out3]) | ||
x_test = DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), | ||
dims=x.type.dims, | ||
) | ||
results = fn(x_test) | ||
expected_results = [ | ||
x_test + x_test, | ||
x_test + x_test.rename(b1="b2", b2="b1"), | ||
x_test + x_test.rename(b2="b3"), | ||
] | ||
for result, expected_result in zip(results, expected_results): | ||
xr_assert_allclose(result, expected_result) | ||
|
||
|
||
def test_chained_operations(): | ||
x = xtensor("x", dims=("city",), shape=(None,)) | ||
y = xtensor("y", dims=("country",), shape=(4,)) | ||
z = add(exp(x), exp(y)) | ||
assert z.type.dims == ("city", "country") | ||
assert z.type.shape == (None, 4) | ||
|
||
fn = function([x, y], z) | ||
|
||
x_test = DataArray(np.zeros(3), dims="city") | ||
y_test = DataArray(np.ones(4), dims="country") | ||
|
||
np.testing.assert_allclose( | ||
fn(x_test.values, y_test.values), | ||
(np.exp(x_test) + np.exp(y_test)).values, | ||
) | ||
|
||
|
||
def test_multiple_constant(): | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
out = exp(x * 2) + 2 | ||
|
||
fn = function([x], out) | ||
|
||
x_test = np.zeros((2, 3), dtype=x.type.dtype) | ||
res = fn(x_test) | ||
expected_res = np.exp(x_test * 2) + 2 | ||
np.testing.assert_allclose(res, expected_res) | ||
|
||
|
||
def test_cast(): | ||
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32") | ||
yf64 = x.astype("float64") | ||
yi16 = x.astype("int16") | ||
ybool = x.astype("bool") | ||
|
||
fn = xr_function([x], [yf64, yi16, ybool]) | ||
x_test = xr_arange_like(x) | ||
res_f64, res_i16, res_bool = fn(x_test) | ||
xr_assert_allclose(res_f64, x_test.astype("float64")) | ||
xr_assert_allclose(res_i16, x_test.astype("int16")) | ||
xr_assert_allclose(res_bool, x_test.astype("bool")) | ||
|
||
yc64 = x.astype("complex64") | ||
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): | ||
yc64.astype("float64") | ||
|
||
|
||
def test_dot(): | ||
"""Test basic dot product operations.""" | ||
# Test matrix-vector dot product (with multiple-letter dim names) | ||
x = xtensor("x", dims=("aa", "bb"), shape=(2, 3)) | ||
y = xtensor("y", dims=("bb",), shape=(3,)) | ||
z = x.dot(y) | ||
fn = xr_function([x, y], z) | ||
|
||
x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb")) | ||
y_test = DataArray(np.ones(3), dims=("bb",)) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test matrix-vector dot product with ellipsis | ||
z = x.dot(y, dim=...) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim=...) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test matrix-matrix dot product | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
y = xtensor("y", dims=("b", "c"), shape=(3, 4)) | ||
z = x.dot(y) | ||
fn = xr_function([x, y], z) | ||
|
||
x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b")) | ||
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c")) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test matrix-matrix dot product with string dim | ||
z = x.dot(y, dim="b") | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim="b") | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test matrix-matrix dot product with list of dims | ||
z = x.dot(y, dim=["b"]) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim=["b"]) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test matrix-matrix dot product with ellipsis | ||
z = x.dot(y, dim=...) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim=...) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test a case where there are two dimensions to sum over | ||
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) | ||
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5)) | ||
z = x.dot(y) | ||
fn = xr_function([x, y], z) | ||
|
||
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c")) | ||
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d")) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Same but with explicit dimensions | ||
z = x.dot(y, dim=["b", "c"]) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim=["b", "c"]) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Same but with ellipses | ||
z = x.dot(y, dim=...) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test, dim=...) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Dot product with sum | ||
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c")) | ||
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d")) | ||
expected = x_test.dot(y_test, dim=("a", "b", "c")) | ||
|
||
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) | ||
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5)) | ||
z = x.dot(y, dim=("a", "b", "c")) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Dot product with sum in the middle | ||
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d")) | ||
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e")) | ||
expected = x_test.dot(y_test, dim=("b", "d")) | ||
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5)) | ||
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6)) | ||
z = x.dot(y, dim=("b", "d")) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Same but with first two dims | ||
expected = x_test.dot(y_test, dim=["a", "b"]) | ||
z = x.dot(y, dim=["a", "b"]) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Same but with last two | ||
expected = x_test.dot(y_test, dim=["d", "e"]) | ||
z = x.dot(y, dim=["d", "e"]) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Same but with every other dim | ||
expected = x_test.dot(y_test, dim=["a", "c", "e"]) | ||
z = x.dot(y, dim=["a", "c", "e"]) | ||
fn = xr_function([x, y], z) | ||
z_test = fn(x_test, y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
# Test symbolic shapes | ||
x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic | ||
y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic | ||
z = x.dot(y) | ||
fn = xr_function([x, y], z) | ||
x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) | ||
y_test = DataArray(np.ones((3, 4)), dims=("b", "c")) | ||
z_test = fn(x_test, y_test) | ||
expected = x_test.dot(y_test) | ||
xr_assert_allclose(z_test, expected) | ||
|
||
|
||
def test_dot_errors(): | ||
# No matching dimensions | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
y = xtensor("y", dims=("b", "c"), shape=(3, 4)) | ||
with pytest.raises(ValueError, match="Dimension e not found in either input"): | ||
x.dot(y, dim="e") | ||
|
||
# Concrete dimension size mismatches | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
y = xtensor("y", dims=("b", "c"), shape=(4, 5)) | ||
with pytest.raises( | ||
ValueError, | ||
match="Size of dim 'b' does not match", | ||
): | ||
x.dot(y) | ||
|
||
# Symbolic dimension size mismatches | ||
x = xtensor("x", dims=("a", "b"), shape=(2, None)) | ||
y = xtensor("y", dims=("b", "c"), shape=(None, 5)) | ||
z = x.dot(y) | ||
fn = xr_function([x, y], z) | ||
x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) | ||
y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) | ||
# Doesn't fail until the rewrite | ||
with pytest.raises(ValueError, match="not aligned"): | ||
fn(x_test, y_test) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# ruff: noqa: E402 | ||
import pytest | ||
|
||
|
||
pytest.importorskip("xarray") | ||
|
||
from pytensor.xtensor.type import xtensor | ||
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"] | ||
) | ||
@pytest.mark.parametrize( | ||
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:] | ||
) | ||
def test_reduction(method, dim): | ||
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) | ||
out = getattr(x, method)(dim=dim) | ||
|
||
fn = xr_function([x], out) | ||
x_test = xr_arange_like(x) | ||
|
||
xr_assert_allclose( | ||
fn(x_test), | ||
getattr(x_test, method)(dim=dim), | ||
) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# ruff: noqa: E402 | ||
import pytest | ||
|
||
|
||
pytest.importorskip("xarray") | ||
|
||
import numpy as np | ||
from xarray import DataArray | ||
|
||
from pytensor.graph.basic import equal_computations | ||
from pytensor.tensor import as_tensor, specify_shape, tensor | ||
from pytensor.xtensor import xtensor | ||
from pytensor.xtensor.type import XTensorType, as_xtensor | ||
|
||
|
||
def test_xtensortype(): | ||
x1 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3)) | ||
x2 = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3)) | ||
x3 = XTensorType(dtype="float64", dims=("a", "b"), shape=(None, 3)) | ||
y1 = XTensorType(dtype="float64", dims=("c", "d"), shape=(4, 5)) | ||
z1 = XTensorType(dtype="float32", dims=("a", "b"), shape=(2, 3)) | ||
|
||
assert x1 == x2 and x1.is_super(x2) and x2.is_super(x1) | ||
assert x1 != x3 and not x1.is_super(x3) and x3.is_super(x1) | ||
assert x1 != y1 and not x1.is_super(y1) and not y1.is_super(x1) | ||
assert x1 != z1 and not x1.is_super(z1) and not z1.is_super(x1) | ||
|
||
|
||
def test_xtensortype_filter_variable(): | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
|
||
y1 = xtensor("y1", dims=("a", "b"), shape=(2, 3)) | ||
assert x.type.filter_variable(y1) is y1 | ||
|
||
y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2)) | ||
expected_y2 = y2.transpose() | ||
assert equal_computations([x.type.filter_variable(y2)], [expected_y2]) | ||
|
||
y3 = xtensor("y3", dims=("b", "a"), shape=(3, None)) | ||
expected_y3 = as_xtensor( | ||
specify_shape(y3.transpose().values, (2, 3)), dims=("a", "b") | ||
) | ||
assert equal_computations([x.type.filter_variable(y3)], [expected_y3]) | ||
|
||
# Cases that fail | ||
with pytest.raises(TypeError): | ||
y4 = xtensor("y4", dims=("a", "b"), shape=(3, 2)) | ||
x.type.filter_variable(y4) | ||
|
||
with pytest.raises(TypeError): | ||
y5 = xtensor("y5", dims=("a", "c"), shape=(2, 3)) | ||
x.type.filter_variable(y5) | ||
|
||
with pytest.raises(TypeError): | ||
y6 = xtensor("y6", dims=("a", "b", "c"), shape=(2, 3, 4)) | ||
x.type.filter_variable(y6) | ||
|
||
with pytest.raises(TypeError): | ||
y7 = xtensor("y7", dims=("a", "b"), shape=(2, 3), dtype="int32") | ||
x.type.filter_variable(y7) | ||
|
||
z1 = tensor("z1", shape=(2, None)) | ||
expected_z1 = as_xtensor(specify_shape(z1, (2, 3)), dims=("a", "b")) | ||
assert equal_computations([x.type.filter_variable(z1)], [expected_z1]) | ||
|
||
# Cases that fail | ||
with pytest.raises(TypeError): | ||
z2 = tensor("z2", shape=(3, 2)) | ||
x.type.filter_variable(z2) | ||
|
||
with pytest.raises(TypeError): | ||
z3 = tensor("z3", shape=(1, 2, 3)) | ||
x.type.filter_variable(z3) | ||
|
||
with pytest.raises(TypeError): | ||
z4 = tensor("z4", shape=(2, 3), dtype="int32") | ||
x.type.filter_variable(z4) | ||
|
||
|
||
def test_xtensor_constant(): | ||
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b"))) | ||
assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3)) | ||
|
||
y = as_xtensor(np.ones((2, 3)), dims=("a", "b")) | ||
assert y.type == x.type | ||
assert x.signature() == y.signature() | ||
assert x.equals(y) | ||
x_eval = x.eval() | ||
assert isinstance(x.eval(), np.ndarray) | ||
np.testing.assert_array_equal(x_eval, y.eval(), strict=True) | ||
|
||
z = as_xtensor(np.ones((3, 2)), dims=("b", "a")) | ||
assert z.type != x.type | ||
assert z.signature() != x.signature() | ||
assert not x.equals(z) | ||
np.testing.assert_array_equal(x_eval, z.eval().T, strict=True) | ||
|
||
|
||
def test_as_tensor(): | ||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
|
||
with pytest.raises( | ||
TypeError, | ||
match="PyTensor forbids automatic conversion of XTensorVariable to TensorVariable", | ||
): | ||
as_tensor(x) | ||
|
||
x_pt = as_tensor(x, allow_xtensor_conversion=True) | ||
assert equal_computations([x_pt], [x.values]) | ||
|
||
|
||
def test_minimum_compile(): | ||
from pytensor.compile.mode import Mode | ||
|
||
x = xtensor("x", dims=("a", "b"), shape=(2, 3)) | ||
y = x.transpose() | ||
minimum_mode = Mode(linker="py", optimizer="minimum_compile") | ||
result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode) | ||
np.testing.assert_array_equal(result, np.ones((3, 2))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# ruff: noqa: E402 | ||
import pytest | ||
|
||
|
||
pytest.importorskip("xarray") | ||
|
||
import numpy as np | ||
from xarray import DataArray | ||
from xarray.testing import assert_allclose | ||
|
||
from pytensor import function | ||
from pytensor.xtensor.type import XTensorType | ||
|
||
|
||
def xr_function(*args, **kwargs): | ||
"""Compile and wrap a PyTensor function to return xarray DataArrays.""" | ||
fn = function(*args, **kwargs) | ||
symbolic_outputs = fn.maker.fgraph.outputs | ||
assert all( | ||
isinstance(out.type, XTensorType) for out in symbolic_outputs | ||
), "All outputs must be xtensor" | ||
|
||
def xfn(*xr_inputs): | ||
np_inputs = [ | ||
inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs | ||
] | ||
np_outputs = fn(*np_inputs) | ||
if not isinstance(np_outputs, tuple | list): | ||
return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims) | ||
else: | ||
return tuple( | ||
DataArray(res, dims=out.type.dims) | ||
for res, out in zip(np_outputs, symbolic_outputs) | ||
) | ||
|
||
xfn.fn = fn | ||
return xfn | ||
|
||
|
||
def xr_assert_allclose(x, y, *args, **kwargs): | ||
# Assert that two xarray DataArrays are close, ignoring coordinates | ||
x = x.drop_vars(x.coords) | ||
y = y.drop_vars(y.coords) | ||
assert_allclose(x, y, *args, **kwargs) | ||
|
||
|
||
def xr_arange_like(x): | ||
return DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), | ||
dims=x.type.dims, | ||
) | ||
|
||
|
||
def xr_random_like(x, rng=None): | ||
if rng is None: | ||
rng = np.random.default_rng() | ||
|
||
return DataArray( | ||
rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims | ||
) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new files need a license.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lucianopaz can we bring your pre-commit hook over?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be a separate PR, wouldn't be surprised if have files missing it in main