
Commit a3db755

Committed Jul 23, 2024
POC named tensors
1 parent 981688c

File tree

8 files changed: +457 additions, −0 deletions

‎pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    as_xtensor_variable,
    xtensor,
    xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
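Importing the package triggers the experimental warning above. A minimal usage sketch of the entry points re-exported here (the printed dtype assumes the default floatX config of float64):

import pytensor.xtensor as px  # warns: "xtensor module is experimental and full of bugs"

x = px.xtensor("x", dims=("a",))
print(x.type)  # XTensorType(float64, ('a',), (None,))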

‎pytensor/xtensor/basic.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import pytensor.scalar as ps
import pytensor.xtensor as px
from pytensor.graph import Apply, Op
from pytensor.tensor import TensorType


class TensorFromXTensor(Op):
    # TODO: May need mapping of named dims to positional dims?

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, px.XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x.copy()


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(Op):
    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = px.XTensorType(x.type.dtype, dims=self.dims, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x.copy()


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)


class XElemwise(Op):
    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        inputs = [px.as_xtensor(inp) for inp in inputs]

        # TODO: This ordering is different than what xarray does
        unique_dims: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in unique_dims:
                    unique_dims[dim] = dim_length
                elif dim_length is not None:
                    # Check for conflicting shapes
                    if (unique_dims[dim] is not None) and (
                        unique_dims[dim] != dim_length
                    ):
                        raise ValueError(f"Dimension {dim} has conflicting shapes")
                    # Keep the non-None shape
                    unique_dims[dim] = dim_length

        dims, shape = zip(*sorted(unique_dims.items()))

        # TODO: Fix dtype
        output_type = px.XTensorType("float64", dims=dims, shape=shape)
        outputs = [output_type() for _ in range(self.scalar_op.nout)]
        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs) -> None:
        raise NotImplementedError(
            "xtensor operations must be rewritten as tensor operations"
        )


add = XElemwise(ps.add)
exp = XElemwise(ps.exp)
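To make the dim handling in XElemwise.make_node concrete, a small sketch (variable names are illustrative): the output dims are the alphabetically sorted union of the input dims, and a known dim length wins over an unknown (None) one:

from pytensor.xtensor.basic import add
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("b", "a"), shape=(2, None))
y = xtensor("y", dims=("c",), shape=(5,))
out = add(x, y)
assert out.type.dims == ("a", "b", "c")  # sorted union, unlike xarray's ordering
assert out.type.shape == (None, 2, 5)    # known lengths are kept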
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
import pytensor.xtensor.rewriting.basic

‎pytensor/xtensor/rewriting/basic.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import expand_dims
from pytensor.tensor.elemwise import Elemwise
from pytensor.xtensor.basic import (
    TensorFromXTensor,
    XElemwise,
    XTensorFromTensor,
    tensor_from_xtensor,
    xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, XTensorFromTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, TensorFromXTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    # Convert inputs to TensorVariables and add broadcastable dims
    output_dims = node.outputs[0].type.dims

    tensor_inputs = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
        tensor_inp = tensor_from_xtensor(inp)
        tensor_inp = expand_dims(tensor_inp, axis)
        tensor_inputs.append(tensor_inp)

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )

    # TODO: copy_stack_trace
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=output_dims) for tensor_out in tensor_outs
    ]
    return new_outs
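The axis computation in xelemwise_to_elemwise relies on each input's dims appearing in the same relative order as the output dims; a standalone sketch of how the broadcastable axes are chosen:

# For an input with dims ("b",) and output dims ("a", "b", "c"),
# new length-1 axes are inserted wherever the input lacks a dim:
output_dims = ("a", "b", "c")
inp_dims = ("b",)
axis = [i for i, dim in enumerate(output_dims) if dim not in inp_dims]
assert axis == [0, 2]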

‎pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase


optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)


def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: RewriteDatabase | NodeRewriter):
            return register_xcanonicalize(
                inner_rewriter, node_rewriter, *tags, **kwargs
            )

        return register

    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter
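register_xcanonicalize supports both direct decoration and decoration with an extra tag string; a sketch of the two conventions (my_rewrite is hypothetical):

from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import TensorFromXTensor
from pytensor.xtensor.rewriting.utils import register_xcanonicalize

@register_xcanonicalize  # registered under the rewriter's __name__
@node_rewriter(tracks=[TensorFromXTensor])
def my_rewrite(fgraph, node):
    ...

# Passing a string first returns a decorator and registers it as an extra tag:
# @register_xcanonicalize("my_tag")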

‎pytensor/xtensor/type.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
try:
    import xarray as xr

    XARRAY_AVAILABLE = True
except ModuleNotFoundError:
    XARRAY_AVAILABLE = False

from collections.abc import Sequence
from typing import TypeVar

import numpy as np

from pytensor import _as_symbolic, config
from pytensor.graph import Apply, Constant
from pytensor.graph.basic import Variable
from pytensor.graph.type import HasDataType, HasShape, Type
from pytensor.tensor.utils import hash_from_ndarray
from pytensor.utils import hash_from_code


class XTensorType(Type, HasDataType, HasShape):
    """A `Type` for XTensors (xarray-like tensors with named dims)."""

    __props__ = ("dtype", "shape", "dims")

    def __init__(
        self,
        dtype: str | np.dtype,
        *,
        dims: Sequence[str],
        shape: Sequence[int | None] | None = None,
        name: str | None = None,
    ):
        if dtype == "floatX":
            self.dtype = config.floatX
        else:
            if np.obj2sctype(dtype) is None:
                raise TypeError(f"Invalid dtype: {dtype}")

            self.dtype = np.dtype(dtype).name

        self.dims = tuple(dims)
        if shape is None:
            self.shape = (None,) * len(self.dims)
        else:
            self.shape = tuple(shape)
        self.name = name

    def clone(
        self,
        dtype=None,
        dims=None,
        shape=None,
        **kwargs,
    ):
        if dtype is None:
            dtype = self.dtype
        if dims is None:
            dims = self.dims
        if shape is None:
            shape = self.shape
        return type(self)(dtype, dims=dims, shape=shape, **kwargs)

    def filter(self, value, strict=False, allow_downcast=None):
        # TODO: Implement this
        return value

    def convert_variable(self, var):
        # TODO: Implement this
        return var

        res = super().convert_variable(var)

        if res is None:
            return res

        if not isinstance(res.type, type(self)):
            return None

        if res.dims != self.dims:
            # TODO: Does this make sense?
            return None

        return res

    def __hash__(self):
        return hash((super().__hash__(), self.shape, self.dims))

    def __repr__(self):
        # TODO: Add `?` for unknown shapes like `TensorType` does
        return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"

    def __eq__(self, other):
        res = super().__eq__(other)

        if isinstance(res, bool):
            return res and self.dims == other.dims and self.shape == other.shape

        return res

    def is_super(self, otype):
        # TODO: Implement this
        return True

        if not super().is_super(otype):
            return False

        if self.dims == otype.dims:
            return True

        return False


def xtensor(
    name: str | None = None,
    *,
    dims: Sequence[str],
    shape: Sequence[int | None] | None = None,
    dtype: str | np.dtype = "floatX",
):
    return XTensorType(dtype, dims=dims, shape=shape)(name=name)


# class _x_tensor_py_operators


class XTensorVariable(Variable):
    pass

    # def __str__(self):
    #     return f"{self.__class__.__name__}{{{self.format},{self.dtype}}}"

    # def __repr__(self):
    #     return str(self)


class XTensorConstantSignature(tuple):
    def __eq__(self, other):
        if type(self) is not type(other):
            return False

        (t0, d0), (t1, d1) = self, other
        if t0 != t1 or d0.shape != d1.shape:
            return False

        return True

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        (a, b) = self
        return hash(type(self)) ^ hash(a) ^ hash(type(b))

    def pytensor_hash(self):
        t, d = self
        return "".join([hash_from_ndarray(d)] + [hash_from_code(dim) for dim in t.dims])


_XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType)


class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]):
    def __init__(self, type: _XTensorTypeType, data, name=None):
        # TODO: Add checks that type and data are compatible
        Constant.__init__(self, type, data, name)

    def signature(self):
        assert self.data is not None
        return XTensorConstantSignature((self.type, self.data))


XTensorType.variable_type = XTensorVariable
XTensorType.constant_type = XTensorConstant


def xtensor_constant(x, name=None):
    # Guard on XARRAY_AVAILABLE so this fails cleanly when xarray is not installed
    if not (XARRAY_AVAILABLE and isinstance(x, xr.DataArray)):
        raise TypeError("xtensor.constant must be called on an xarray DataArray")
    try:
        return XTensorConstant(
            XTensorType(dtype=x.dtype, dims=x.dims, shape=x.shape),
            x.values.copy(),
            name=name,
        )
    except TypeError:
        raise TypeError(f"Could not convert {x} to XTensorType")


if XARRAY_AVAILABLE:

    @_as_symbolic.register(xr.DataArray)
    def as_symbolic_xarray(x, **kwargs):
        return xtensor_constant(x, **kwargs)


def as_xtensor_variable(x, name=None):
    if isinstance(x, Apply):
        if len(x.outputs) != 1:
            raise ValueError(
                "It is ambiguous which output of a multi-output Op has to be fetched.",
                x,
            )
        else:
            x = x.outputs[0]
    if isinstance(x, Variable):
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"Variable type field must be an XTensorType, got {x.type}")
        return x
    try:
        return xtensor_constant(x, name=name)
    except TypeError as err:
        raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err


as_xtensor = as_xtensor_variable
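A usage sketch for the type and the conversion helpers above (assumes numpy and xarray are installed):

import numpy as np
import xarray as xr

from pytensor.xtensor.type import as_xtensor, xtensor

# Symbolic variable with named dims and a partially known shape
x = xtensor("x", dims=("city", "year"), shape=(None, 3))
assert x.type.dims == ("city", "year")

# Constants are built from xarray DataArrays
c = as_xtensor(xr.DataArray(np.zeros((2, 3)), dims=("city", "year")))
assert c.type.shape == (2, 3)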

‎tests/xtensor/__init__.py

Whitespace-only changes.

‎tests/xtensor/test_basic.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import numpy as np

from pytensor import function
from pytensor.xtensor.basic import add, exp
from pytensor.xtensor.type import xtensor


def test_add():
    x = xtensor("x", dims=("city",), shape=(None,))
    y = xtensor("y", dims=("country",), shape=(4,))
    z = add(exp(x), exp(y))
    assert z.type.dims == ("city", "country")
    assert z.type.shape == (None, 4)

    fn = function([x, y], z)
    # fn.dprint(print_type=True)

    np.testing.assert_allclose(
        fn(x=np.zeros(3), y=np.zeros(4)),
        np.full((3, 4), 2.0),
    )
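For comparison, the same computation in xarray gives the identical outer broadcast, which is what the named-dim semantics are meant to mirror (a sanity check, not part of the commit):

import numpy as np
import xarray as xr

xa = xr.DataArray(np.zeros(3), dims=("city",))
ya = xr.DataArray(np.zeros(4), dims=("country",))
za = np.exp(xa) + np.exp(ya)
assert za.dims == ("city", "country")
np.testing.assert_allclose(za.values, np.full((3, 4), 2.0))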
