Commit 2d908df

POC named tensors

1 parent 5ffe17a commit 2d908df

12 files changed: +710 -0 lines changed

pytensor/xtensor/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    as_xtensor_variable,
    xtensor,
    xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
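
A quick usage sketch (not part of the commit; it assumes the dtype/dims/shape keywords that basic.py below passes to the xtensor constructor):

    from pytensor.xtensor import xtensor

    # Building a named-tensor variable triggers the experimental warning above
    x = xtensor(dtype="float64", dims=("a", "b"), shape=(None, 3))
    print(x.type.dims)  # ("a", "b")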

pytensor/xtensor/basic.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
from itertools import chain

import pytensor.scalar as ps
from pytensor.graph import Apply, Op
from pytensor.tensor import TensorType, tensor
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
    """A base class for XOps that shouldn't be materialized"""

    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            "xtensor operations must be rewritten as tensor operations"
        )


class TensorFromXTensor(Op):
    view_map = {0: [0]}

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(Op):
    view_map = {0: [0]}
    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must have a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])

    def perform(self, node, inputs, output_storage) -> None:
        [x] = inputs
        output_storage[0][0] = x


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)


class XElemwise(XOp):
    __props__ = ("scalar_op",)

    def __init__(self, scalar_op):
        super().__init__()
        self.scalar_op = scalar_op

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]
        if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin):
            raise ValueError(
                f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
            )

        dims_and_shape: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in dims_and_shape:
                    dims_and_shape[dim] = dim_length
                elif dim_length is not None:
                    # Check for conflicting shapes
                    if (dims_and_shape[dim] is not None) and (
                        dims_and_shape[dim] != dim_length
                    ):
                        raise ValueError(f"Dimension {dim} has conflicting shapes")
                    # Keep the non-None shape
                    dims_and_shape[dim] = dim_length

        output_dims, output_shape = zip(*dims_and_shape.items())

        dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
        output_dtypes = [
            out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs
        ]
        outputs = [
            xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape)
            for output_dtype in output_dtypes
        ]
        return Apply(self, inputs, outputs)


class XBlockwise(XOp):
    __props__ = ("core_op", "signature", "core_dims")

    def __init__(
        self,
        core_op: Op,
        signature: str,
        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]],
    ):
        super().__init__()
        self.core_op = core_op
        self.signature = signature
        self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
        self.core_dims = core_dims

    def make_node(self, *inputs):
        inputs = [as_xtensor(i) for i in inputs]
        if len(inputs) != len(self.inputs_sig):
            raise ValueError(
                f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}"
            )

        dims_and_shape: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in dims_and_shape:
                    dims_and_shape[dim] = dim_length
                elif dim_length is not None:
                    # Check for conflicting shapes
                    if (dims_and_shape[dim] is not None) and (
                        dims_and_shape[dim] != dim_length
                    ):
                        raise ValueError(f"Dimension {dim} has conflicting shapes")
                    # Keep the non-None shape
                    dims_and_shape[dim] = dim_length

        core_inputs_dims, core_outputs_dims = self.core_dims
        # TODO: Avoid intermediate dict
        core_dims = set(chain.from_iterable(core_inputs_dims))
        batched_dims_and_shape = {
            k: v for k, v in dims_and_shape.items() if k not in core_dims
        }
        batch_dims, batch_shape = zip(*batched_dims_and_shape.items())

        dummy_core_inputs = []
        for inp, core_inp_dims in zip(inputs, core_inputs_dims):
            try:
                core_static_shape = [
                    inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims
                ]
            except ValueError:  # tuple.index raises ValueError on a missing dim
                raise ValueError(
                    f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}"
                )
            dummy_core_inputs.append(
                tensor(dtype=inp.type.dtype, shape=core_static_shape)
            )
        core_node = self.core_op.make_node(*dummy_core_inputs)

        outputs = [
            xtensor(
                dtype=core_out.type.dtype,
                shape=batch_shape + core_out.type.shape,
                dims=batch_dims + core_out_dims,
            )
            for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
        ]
        return Apply(self, inputs, outputs)
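
A sketch of how XElemwise builds a graph (illustrative, not from the commit; ps.add is the scalar addition Op, and alignment-by-name is what make_node above implements):

    import pytensor.scalar as ps
    from pytensor.xtensor.basic import XElemwise
    from pytensor.xtensor.type import xtensor

    x = xtensor(dtype="float64", dims=("a", "b"), shape=(None, 3))
    y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, None))
    # Inputs are matched by dim name; the output carries the union of dims
    out = XElemwise(ps.add)(x, y)
    assert out.type.dims == ("a", "b", "c")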

pytensor/xtensor/linalg.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
from collections.abc import Sequence

from pytensor.tensor.slinalg import Solve
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.basic import XBlockwise


def solve(
    a,
    b,
    dims: Sequence[str],
    assume_a="gen",
    lower: bool = False,
    check_finite: bool = False,
):
    a, b = as_xtensor(a), as_xtensor(b)
    if len(dims) == 2:
        b_ndim = 1
        [m1_dim] = [dim for dim in dims if dim not in b.type.dims]
        m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
        output_core_dims = ((m2_dim,),)
    elif len(dims) == 3:
        b_ndim = 2
        [n_dim] = [dim for dim in dims if dim not in a.type.dims]
        [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
        input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
        output_core_dims = ((m2_dim, n_dim),)
    else:
        raise ValueError("Solve dims must have length 2 or 3")

    core_op = Solve(
        b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
    )
    x_op = XBlockwise(
        core_op,
        signature=core_op.gufunc_signature,
        core_dims=(input_core_dims, output_core_dims),
    )
    return x_op(a, b)
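
A hedged usage sketch (dim names are hypothetical). In the two-dim case the dim missing from b is taken as the rows of a; a batch dimension is included because XBlockwise.make_node unpacks batch dims unconditionally:

    from pytensor.xtensor.linalg import solve
    from pytensor.xtensor.type import xtensor

    A = xtensor(dtype="float64", dims=("batch", "row", "col"), shape=(None, 3, 3))
    b = xtensor(dtype="float64", dims=("batch", "row"), shape=(None, 3))
    # "col" is absent from b, so it becomes m1; b and the solution share m2="row"
    x = solve(A, b, dims=("row", "col"))
    assert x.type.dims == ("batch", "row")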

pytensor/xtensor/math.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import inspect
import sys

import pytensor.scalar as ps
from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import XElemwise


this_module = sys.modules[__name__]


def get_all_scalar_ops():
    """
    Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise.

    Returns:
        dict: A dictionary mapping operation names to XElemwise instances
    """
    result = {}

    # Get all module members
    for name, obj in inspect.getmembers(ps):
        # Keep only instantiated ScalarOps; classes and other members are skipped
        if isinstance(obj, ScalarOp):
            result[name] = XElemwise(obj)

    return result


for name, op in get_all_scalar_ops().items():
    setattr(this_module, name, op)
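
The loop above exposes every ScalarOp instance found in pytensor.scalar as a module-level XElemwise wrapper, so the wrappers read like ordinary functions. A sketch (assuming ps.exp exists, as it does in mainline PyTensor):

    import pytensor.xtensor.math as xmath
    from pytensor.xtensor.type import xtensor

    x = xtensor(dtype="float64", dims=("a",), shape=(4,))
    # xmath.exp is XElemwise(ps.exp), generated by the setattr loop above
    y = xmath.exp(x)
    assert y.type.dims == ("a",)
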
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
import pytensor.xtensor.rewriting.basic

pytensor/xtensor/rewriting/basic.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.xtensor.basic import (
    TensorFromXTensor,
    XBlockwise,
    XElemwise,
    XTensorFromTensor,
    tensor_from_xtensor,
    xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, XTensorFromTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, TensorFromXTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[XElemwise])
def xelemwise_to_elemwise(fgraph, node):
    out_dims = node.outputs[0].type.dims

    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        order = [
            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
            for out_dim in out_dims
        ]
        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
        tensor_inputs.append(tensor_inp)

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )

    # Convert output Tensors to XTensors
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs
    ]
    return new_outs


@register_xcanonicalize
@node_rewriter(tracks=[XBlockwise])
def xblockwise_to_blockwise(fgraph, node):
    op: XBlockwise = node.op
    batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
    batch_dims = node.outputs[0].type.dims[:batch_ndim]

    # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
    tensor_inputs = []
    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
        inp_dims = inp.type.dims
        # Align the batch dims of the input, and place the core dims on the right
        batch_order = [
            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
            for batch_dim in batch_dims
        ]
        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
        tensor_inputs.append(tensor_inp)

    tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature)
    tensor_outs = tensor_op(*tensor_inputs, return_list=True)

    # Convert output Tensors to XTensors
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
        for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
    ]
    return new_outs
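
A sketch of the lowering these rewrites perform (hypothetical; it assumes rewrite_graph can query the "xcanonicalize" database registered in rewriting/utils.py below):

    import pytensor.scalar as ps
    from pytensor.graph.rewriting.utils import rewrite_graph
    from pytensor.xtensor.basic import XElemwise
    from pytensor.xtensor.type import xtensor

    x = xtensor(dtype="float64", dims=("a",), shape=(3,))
    out = XElemwise(ps.add)(x, x)
    # After rewriting, the XElemwise is gone: a regular Elemwise runs between
    # TensorFromXTensor/XTensorFromTensor conversions at the graph boundaries
    rewritten = rewrite_graph(out, include=("xcanonicalize",))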

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase


optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)


def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: RewriteDatabase | NodeRewriter):
            return register_xcanonicalize(
                inner_rewriter, node_rewriter, *tags, **kwargs
            )

        return register

    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter
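
register_xcanonicalize works both as a bare decorator (as used in rewriting/basic.py) and with a leading tag string; a minimal hypothetical sketch:

    from pytensor.graph import node_rewriter
    from pytensor.xtensor.basic import XElemwise
    from pytensor.xtensor.rewriting.utils import register_xcanonicalize

    @register_xcanonicalize("custom_tag")  # the string form adds an extra tag
    @node_rewriter(tracks=[XElemwise])
    def do_nothing_rewrite(fgraph, node):
        # Returning None tells the rewriter there is no replacement
        return None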
