-
Notifications
You must be signed in to change notification settings - Fork 135
Add transpose() for labeled tensors #1427
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add transpose() for labeled tensors #1427
Conversation
@ricardoV94 Please take a look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good, I left some small comments. We should also add it as an XTensorVariable
method like the other ones here:
pytensor/pytensor/xtensor/type.py
Lines 311 to 318 in e32d865
# def swap_dims(self, *args, **kwargs): | |
# ... | |
# | |
# def expand_dims(self, *args, **kwargs): | |
# ... | |
# | |
# def squeeze(self): | |
# ... |
And similarly https://docs.xarray.dev/en/latest/generated/xarray.DataArray.T.html#xarray.DataArray.T.
This way you can do x.transpose()
and x.T
. these methods should call the helper you created.
pytensor/xtensor/rewriting/shape.py
Outdated
# Determine the permutation of axes | ||
out_dims = node.op.dims | ||
in_dims = x.type.dims | ||
expanded_dims = expand_ellipsis(out_dims, in_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need to expand_ellipsis again, you have the ground truth in node.outputs[0].type.dims
that you already computed in make_node
pytensor/xtensor/shape.py
Outdated
class Transpose(XOp): | ||
__props__ = ("dims",) | ||
|
||
def __init__(self, dims: tuple[str, ...]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the type hint is wrong, because dims can include ellipsis
def __init__(self, dims: tuple[str, ...]): | |
def __init__(self, dims: tuple[str | Ellipsis, ...]): |
pytensor/xtensor/shape.py
Outdated
|
||
|
||
def transpose(x, *dims): | ||
return Transpose(dims)(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xarray has a missing_dims
, that we could provide here as well: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.transpose.html
We already do that for isel
:
pytensor/pytensor/xtensor/type.py
Lines 362 to 368 in e32d865
def isel( | |
self, | |
indexers: dict[str, Any] | None = None, | |
drop: bool = False, # Unused by PyTensor | |
missing_dims: Literal["raise", "warn", "ignore"] = "raise", | |
**indexers_kwargs, | |
): |
@@ -73,6 +73,67 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) | |||
return y | |||
|
|||
|
|||
def expand_ellipsis( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this check that there's at most one dims and raise otherwise?
… matching xarray behavior; update tests and rewrites accordingly
… matching xarray behavior; update tests and rewrites accordingly
@ricardoV94 I think I've addressed all of your comments -- please take another look. |
e32d865
to
d8fe0d1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good. Just some issue with the type hint, and because of that I took the liberty of leaving some other nitpicky comments. Feel free to not address them if you disagree / it's too cumbersome
d8fe0d1
to
29b954a
Compare
This PR adds support for expanding dimensions in labeled tensors, similar to xarray's functionality. The key changes include:
The implementation allows users to add new dimensions to their labeled tensors while maintaining the semantic meaning of the existing dimensions, which is particularly useful for broadcasting and reshaping operations in data analysis workflows. |
29b954a
to
5a7b23c
Compare
@ricardoV94 I've added an implementation of Squeeze. Please take a look. |
|
||
# Create new shape with the new dimension | ||
new_shape = list(x.type.shape) | ||
new_shape.append(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is wrong, new dims by default are added on the left, but actually xarray allows specifying a non-default axis: https://docs.xarray.dev/en/stable/generated/xarray.DataArray.expand_dims.html
if self.dim not in x.type.dims: | ||
raise ValueError(f"Dimension {self.dim} not found") | ||
dim_idx = x.type.dims.index(self.dim) | ||
if x.type.shape[dim_idx] != 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x.type.shape may also be not known None
in which case it can be 1 or something else, it's fine to let it through in that case (it will fail at runtime if nto 1). Only if x.type.shape is not None and not 1 do we know it must be invalid
xr_assert_allclose(fn(x_test), x_test.transpose()) | ||
|
||
|
||
def test_expand_dims(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you had kept the template of comparing against xarray like the other tests, you would have noticed expand_dims is not doing the same as xarray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, comparing against xarray makes perfect sense. I will do that in the next iteration.
xr_assert_allclose(res_multi, expected_res_multi) | ||
|
||
|
||
def test_lower_squeeze(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is nonsensical
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cursor and I put that in while debugging -- and it did help. But if it doesn't make sense to keep it as a test, I'll remove it.
@AllenDowney I picked and cleanup up the transpose related changes into #1430. Regarding squeeze and expand_dims, which I suggest you tackle on a separate PR. Two notes:
|
Add transpose operation for labeled tensors
This PR implements the transpose operation for labeled tensors (XTensor), which was previously marked as not implemented in the test suite. The implementation follows the xarray-like API for labeled tensors and includes support for ellipsis in dimension permutations.
Changes
Transpose
XOp class inpytensor/xtensor/shape.py
that handles dimension reordering while preserving labelslower_transpose
rewrite rule inpytensor/xtensor/rewriting/shape.py
to convert labeled tensor operations into regular tensor operationsexpand_ellipsis
helper function to handle ellipsis expansion in dimension permutationstest_transpose
intests/xtensor/test_shape.py
to verify the implementationFeatures
Testing
The implementation passes all test cases, including:
Related
This PR is part of the larger labeled tensors feature (PR #1411) and implements one of the planned shape operations.
📚 Documentation preview 📚: https://pytensor--1427.org.readthedocs.build/en/1427/