Skip to content

Add unstack op #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytensor.tensor import moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Stack
from pytensor.xtensor.shape import Stack, UnStack


@register_xcanonicalize
Expand All @@ -27,3 +27,19 @@ def lower_stack(fgraph, node):

new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[UnStack])
def lower_unstack(fgraph, node):
    """Lower an UnStack node to plain tensor operations (moveaxis + reshape)."""
    [x] = node.inputs
    unstack_axis = x.type.dims.index(node.op.old_dim_name)

    # Move the stacked axis to the last position so a single row-major
    # reshape can split it into the unstacked dimensions.
    tensor_x = tensor_from_xtensor(x)
    tensor_last = moveaxis(tensor_x, source=[unstack_axis], destination=[-1])
    batch_shape = tensor_last.shape[:-1]
    reshaped = tensor_last.reshape((*batch_shape, *node.op.unstacked_lengths))

    # Re-wrap with the output dims declared by UnStack.make_node
    # (batch dims first, unstacked dims appended at the end).
    return [xtensor_from_tensor(reshaped, dims=node.outputs[0].type.dims)]
72 changes: 72 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,75 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
)
y = Stack(new_dim_name, tuple(stacked_dims))(y)
return y


class UnStack(XOp):
    """Unstack one dimension of an XTensorVariable into several dimensions.

    Inverse of ``Stack``: the ``old_dim_name`` dimension is split into
    ``unstacked_dims`` with sizes ``unstacked_lengths``.  The remaining
    ("batch") dims keep their order and the unstacked dims are appended at
    the end of the output's dims.

    Parameters
    ----------
    old_dim_name
        Name of the existing dimension to be split.
    unstacked_dims
        Names of the new dimensions (at least two, unique, and distinct
        from ``old_dim_name``).
    unstacked_lengths
        Static length of each new dimension, aligned with ``unstacked_dims``.

    Raises
    ------
    ValueError
        If the dim names/lengths are inconsistent (see checks below).
    """

    __props__ = ("old_dim_name", "unstacked_dims", "unstacked_lengths")

    def __init__(
        self,
        old_dim_name: str,
        unstacked_dims: tuple[str, ...],
        unstacked_lengths: tuple[int, ...],
    ):
        super().__init__()
        if old_dim_name in unstacked_dims:
            raise ValueError(
                f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
            )
        if len(unstacked_dims) != len(unstacked_lengths):
            raise ValueError(
                "Tuples with unstacked dim names and lengths must have the same length "
                f"but have {len(unstacked_dims)} and {len(unstacked_lengths)}"
            )
        if not unstacked_dims:
            raise ValueError("Dims to unstack into can't be empty.")
        if len(unstacked_dims) == 1:
            raise ValueError("Only one dimension to unstack into, use rename instead")
        # Duplicate names would yield an output type with repeated dims,
        # which is ambiguous for any later dim lookup.
        if len(set(unstacked_dims)) != len(unstacked_dims):
            raise ValueError(
                f"Dims to unstack into must be unique but got {unstacked_dims}"
            )
        self.old_dim_name = old_dim_name
        self.unstacked_dims = unstacked_dims
        self.unstacked_lengths = unstacked_lengths

    def make_node(self, x):
        """Build the Apply node, validating dims against the input's type."""
        x = as_xtensor(x)
        if self.old_dim_name not in x.type.dims:
            raise ValueError(
                f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
            )
        if not set(self.unstacked_dims).isdisjoint(x.type.dims):
            raise ValueError(
                f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
            )
        if x.type.ndim == 1:
            # zip(*()) below would fail on an empty generator, so handle the
            # "only the stacked dim" case explicitly.
            batch_dims, batch_shape = (), ()
        else:
            # Everything except the stacked dim is carried through unchanged.
            batch_dims, batch_shape = zip(
                *(
                    (dim, shape)
                    for dim, shape in zip(x.type.dims, x.type.shape)
                    if dim != self.old_dim_name
                )
            )

        output = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, *self.unstacked_lengths),
            dims=(*batch_dims, *self.unstacked_dims),
        )
        return Apply(self, [x], [output])


def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
    """Unstack dimensions of an XTensorVariable.

    Accepts either a single positional mapping ``dim`` or keyword arguments,
    each mapping a stacked dim name to a ``{new_dim: length}`` dict.  One
    UnStack op is applied per entry, in mapping order.
    """
    if dim is not None:
        if dims:
            raise ValueError(
                "Cannot use both positional dim and keyword dims in unstack"
            )
        dims = dim

    result = x
    for stacked_name, unstack_spec in dims.items():
        new_names = tuple(unstack_spec.keys())
        new_lengths = tuple(unstack_spec.values())
        result = UnStack(stacked_name, new_names, new_lengths)(result)
    return result
58 changes: 57 additions & 1 deletion tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from xarray import DataArray

from pytensor.xtensor.shape import stack
from pytensor.xtensor.shape import stack, unstack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function

Expand Down Expand Up @@ -102,3 +102,59 @@ def test_multiple_stacks():
res = fn(x_test)
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
xr_assert_allclose(res[0], expected_res)


def test_unstack():
    """Test unstack against xarray by round-tripping through DataArray.stack.

    For every subset of >= 2 dims out of {a, b, c, d}, unstack the flat
    "abcd" dim into that subset plus (when the subset is proper) an "other"
    dim holding the product of the remaining lengths.
    """
    # NOTE(review): `powerset` is not imported in this hunk — presumably it
    # comes from an import above the visible diff; verify it exists.
    unstacked_dims = {"a": 2, "b": 3, "c": 5, "d": 7}
    dims = ("abcd",)
    # Single flat dim whose length is the product of all unstacked lengths.
    x = xtensor("x", dims=dims, shape=(2 * 3 * 5 * 7,))
    outs = [
        unstack(
            x,
            abcd=(
                # Dims explicitly unstacked, with their declared lengths...
                {d: l for d, l in unstacked_dims.items() if d in dims_to_unstack}
                | (
                    # ...plus an "other" dim absorbing the leftover lengths,
                    # unless the subset already covers every dim.
                    {}
                    if set(dims_to_unstack) == set(unstacked_dims)
                    else {
                        "other": int(
                            np.prod(
                                [
                                    l
                                    for d, l in unstacked_dims.items()
                                    if d not in dims_to_unstack
                                ]
                            )
                        )
                    }
                )
            ),
        )
        for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
    ]
    fn = xr_function([x], outs)
    # we test through the complementary operation in xarray to avoid needing coords
    # which are required for unstack. We end up with a subset of {a, b, c, d} and
    # other after unstacking, so we create the fully unstacked dataarray
    # and stack to create this extra "other" dimension as needed
    x_test = DataArray(
        np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
            list(unstacked_dims.values())
        ),
        dims=list(unstacked_dims.keys()),
    )
    res = fn(x_test)

    # Expected values: stack the leftover dims into "other" (identity stack
    # when the subset covers everything), mirroring each output above.
    expected_res = [
        x_test.stack(
            {}
            if set(dims_to_unstack) == set(unstacked_dims)
            else {"other": [d for d in unstacked_dims if d not in dims_to_unstack]}
        )
        for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
    ]
    for res_i, expected_res_i in zip(res, expected_res):
        assert res_i.shape == expected_res_i.shape
        # the shapes are right but the "other" one has the elements in different order
        # I think it is an issue with the test not the function but not sure
        # xr_assert_allclose(res_i, expected_res_i)