first pass at unstack #1412

Open
wants to merge 4 commits into
base: labeled_tensors
@ricardoV94 ricardoV94 commented May 22, 2025

@OriolAbril I am opening a branch with your code here on PyTensor. It can be from your fork if you prefer, but that would have to be you doing it.

Copying your messages:

First pass at unstack. It is working already, need to sort out tests and double check the order in which unstack happens.

@ricardoV94 let me know if the PR should have been done in a different way and how the code looks. As I commented in the test code itself, tests currently pass, but I am only checking that shapes match xarray's; the actual elements are different. I still have to figure out whether my idea of testing the complementary operation (to circumvent the fact that xarray's unstack needs coordinates) simply can't be used, or whether I am inverting the order somewhere, with one stack(new_dim=["a", "b"]) while the other has ["b", "a"].


📚 Documentation preview 📚: https://pytensor--1412.org.readthedocs.build/en/1412/

Comment on lines +112 to +132
unstack(
    x,
    abcd=(
        {d: l for d, l in unstacked_dims.items() if d in dims_to_unstack}
        | (
            {}
            if set(dims_to_unstack) == set(unstacked_dims)
            else {
                "other": int(
                    np.prod(
                        [
                            l
                            for d, l in unstacked_dims.items()
                            if d not in dims_to_unstack
                        ]
                    )
                )
            }
        )
    ),
)
Member Author

This is a bit hard for me to read

Member

It was already hard to follow when I wrote it; now, after the formatting, it is a nightmare. I'll try to simplify things a bit tomorrow.
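
One possible way to make the `abcd=` argument in the snippet under review easier to read is to build the mapping in a small helper first. This is only a sketch: the helper name `unstack_sizes` is hypothetical, and it assumes `unstacked_dims` is a dict of dim name to length and `dims_to_unstack` is a sequence of dim names, as the snippet suggests.

```python
import math


def unstack_sizes(unstacked_dims, dims_to_unstack):
    """Build the sizes mapping for unstack (hypothetical helper, not part of the PR).

    Dims listed in `dims_to_unstack` keep their own length; any remaining
    dims are lumped together into a single "other" dim.
    """
    selected = {
        dim: length
        for dim, length in unstacked_dims.items()
        if dim in dims_to_unstack
    }
    if set(dims_to_unstack) == set(unstacked_dims):
        return selected
    remaining = [
        length
        for dim, length in unstacked_dims.items()
        if dim not in dims_to_unstack
    ]
    # math.prod of an empty list is 1, matching int(np.prod([]))
    return {**selected, "other": math.prod(remaining)}
```

With such a helper the call in the test could read `unstack(x, abcd=unstack_sizes(unstacked_dims, dims_to_unstack))`.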

Member Author

@ricardoV94 ricardoV94 May 22, 2025

I think I get what you were trying to do with the test (60% confidence), and I think it has no parallel to what our poor man's unstack can do. We can basically only unstack "consecutive dimensions", whereas xarray will always know what a bunch of stacked dimensions correspond to, and can unstack "non-consecutive/arbitrarily ordered" dimensions.

I think for our purposes we want to always get an identity if we do transpose(unstack(stack(new_dim=stacked_dims), new_dim=original_stacked_dims), original_dims), where original_stacked_dims contains the same dims, in the same order and with the same sizes.

I added a test more like that; maybe we can parametrize it with the powerset approach?
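
The round-trip property described above can be illustrated with plain NumPy, treating a labeled array as an (array, dims) pair. This is only a sketch under stated assumptions: the helpers `stack_dims`, `unstack_dims` and `transpose_to` are hypothetical and are not PyTensor's or xarray's API; stacking is modelled as a transpose followed by a reshape, and unstacking as a reshape that splits one axis into consecutive axes.

```python
import numpy as np


def stack_dims(arr, dims, new_dim, stacked):
    """Move `stacked` dims to the end (in the given order) and merge them into `new_dim`."""
    kept = [d for d in dims if d not in stacked]
    order = [dims.index(d) for d in kept + list(stacked)]
    moved = arr.transpose(order)
    return moved.reshape(moved.shape[: len(kept)] + (-1,)), tuple(kept) + (new_dim,)


def unstack_dims(arr, dims, old_dim, sizes):
    """Split `old_dim` into consecutive dims; `sizes` maps new dim name -> length, in order."""
    axis = dims.index(old_dim)
    new_shape = arr.shape[:axis] + tuple(sizes.values()) + arr.shape[axis + 1 :]
    new_dims = dims[:axis] + tuple(sizes) + dims[axis + 1 :]
    return arr.reshape(new_shape), new_dims


def transpose_to(arr, dims, target_dims):
    """Reorder the axes of `arr` so that they follow `target_dims`."""
    return arr.transpose([dims.index(d) for d in target_dims]), tuple(target_dims)


original_dims = ("a", "b", "c", "d")
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
sizes = dict(zip(original_dims, x.shape))

# Stack two non-adjacent dims, in an order different from the original one
stacked = ("d", "b")
y, y_dims = stack_dims(x, original_dims, "bd", stacked)

# Same dims, same order, same sizes: the round trip is an identity
z, z_dims = unstack_dims(y, y_dims, "bd", {d: sizes[d] for d in stacked})
roundtrip, roundtrip_dims = transpose_to(z, z_dims, original_dims)

# Same dims but reversed order: the shape matches, the elements do not
w, w_dims = unstack_dims(y, y_dims, "bd", {d: sizes[d] for d in reversed(stacked)})
wrong, _ = transpose_to(w, w_dims, original_dims)
```

In this toy model, `roundtrip` equals `x`, while `wrong` has the same shape as `x` but scrambled elements, which may be the kind of mismatch (matching shapes, different values) mentioned in the PR description.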

@ricardoV94 ricardoV94 mentioned this pull request May 22, 2025
# xr_assert_allclose(res_i, expected_res_i)


def test_unstack_simple():
Member Author

@OriolAbril I just added a simple test to convince myself that things look correct, and they do. It isn't meant to replace your more exhaustive test, and we can remove it.

Member

Thanks. It is potentially more exhaustive, but seeing this makes me a bit more convinced that the issue is in the test and not the function, so the complex one might need some rethinking.



class UnStack(XOp):
    __props__ = ("old_dim_name", "unstacked_dims", "unstacked_lengths")
Member Author

@ricardoV94 ricardoV94 May 22, 2025

It seems like nothing requires "unstacked_lengths" to be constant/non-symbolic, so we could parametrize this Op with just ("old_dim_name", "unstacked_dims") and pass "unstacked_lengths" to make_node. We can convert those to scalar TensorVariables with as_tensor(x, ndim=0) and check that the dtype is integer.

Everything in the rewrite with reshape would work the same, but we would extract the lengths from node.inputs[1:].

This will allow stuff like:

x = xtensor(dims=("a", "b", "c"))
y = stack(x, bc=("b", "c"))
# do something with stacked y
z = unstack(y, bc=dict(b=x.sizes["b"], c=x.sizes["c"]))

This works without the user having to pre-commit to static shapes for b and c.
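
The split between static properties and runtime inputs suggested above can be sketched in plain Python. This is a toy model only, not PyTensor's Op machinery: the class `UnStackSpec` and its `make_node` signature are hypothetical, and NumPy scalars stand in for symbolic scalar variables. It shows the dim names living in the hashable properties while the lengths are validated as integer scalars at call time.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class UnStackSpec:
    """Hypothetical stand-in for the Op: only static, hashable properties live here."""

    old_dim_name: str
    unstacked_dims: tuple

    def make_node(self, x_dims, *unstacked_lengths):
        """Validate the runtime lengths and return (output dims, validated lengths)."""
        if self.old_dim_name not in x_dims:
            raise ValueError(f"dim {self.old_dim_name!r} not found in {x_dims}")
        if len(unstacked_lengths) != len(self.unstacked_dims):
            raise ValueError("need exactly one length per unstacked dim")
        lengths = []
        for length in unstacked_lengths:
            length = np.asarray(length)
            # Mirrors the suggested check: scalar (ndim == 0) with an integer dtype
            if length.ndim != 0 or not np.issubdtype(length.dtype, np.integer):
                raise TypeError("unstacked lengths must be integer scalars")
            lengths.append(length)
        axis = x_dims.index(self.old_dim_name)
        out_dims = x_dims[:axis] + self.unstacked_dims + x_dims[axis + 1 :]
        return out_dims, lengths


spec = UnStackSpec(old_dim_name="bc", unstacked_dims=("b", "c"))
out_dims, lengths = spec.make_node(("a", "bc"), 3, np.int64(4))
```

Because the lengths are not part of the frozen properties, two specs with the same dim names compare equal regardless of the lengths they are later called with.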

Member

Thanks for the pointers, I'll try to make the updates

@ricardoV94 ricardoV94 mentioned this pull request May 22, 2025