Add transpose() for labeled tensors #1427

Closed

@AllenDowney AllenDowney commented May 27, 2025

Add transpose operation for labeled tensors

This PR implements the transpose operation for labeled tensors (XTensor), which was previously marked as not implemented in the test suite. The implementation follows the xarray-like API for labeled tensors and includes support for ellipsis in dimension permutations.

Changes

  1. Added Transpose XOp class in pytensor/xtensor/shape.py that handles dimension reordering while preserving labels
  2. Added lower_transpose rewrite rule in pytensor/xtensor/rewriting/shape.py to convert labeled tensor operations into regular tensor operations
  3. Added expand_ellipsis helper function to handle ellipsis expansion in dimension permutations
  4. Updated test_transpose in tests/xtensor/test_shape.py to verify the implementation

Features

  • Support for full transpose (reversing all dimensions)
  • Support for partial permutations
  • Support for ellipsis (...) in dimension permutations
  • Proper handling of dimension labels during transposition
  • Conversion to regular tensor operations for evaluation

Testing

The implementation passes all test cases, including:

  • Identity permutation
  • Full transpose
  • Empty permutation (equivalent to full transpose)
  • Swapping last two dimensions
  • Using ellipsis for partial permutations

Related

This PR is part of the larger labeled tensors feature (PR #1411) and implements one of the planned shape operations.


📚 Documentation preview 📚: https://pytensor--1427.org.readthedocs.build/en/1427/

@AllenDowney
Author

@ricardoV94 Please take a look.

Member

@ricardoV94 ricardoV94 left a comment

Looks pretty good, I left some small comments. We should also add it as an XTensorVariable method like the other ones here:

# def swap_dims(self, *args, **kwargs):
# ...
#
# def expand_dims(self, *args, **kwargs):
# ...
#
# def squeeze(self):
# ...

And similarly https://docs.xarray.dev/en/latest/generated/xarray.DataArray.T.html#xarray.DataArray.T.

This way you can do x.transpose() and x.T. These methods should call the helper you created.
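
A minimal sketch of the method-plus-property pattern described above. `LabeledVar` is a hypothetical stand-in for XTensorVariable that tracks only dimension names, and the `transpose` helper here is a simplified assumption, not the helper from the PR.

```python
class LabeledVar:
    """Hypothetical stand-in for XTensorVariable; tracks only dim names."""

    def __init__(self, dims):
        self.dims = tuple(dims)

    def transpose(self, *dims):
        # Delegate to the module-level helper so there is one code path.
        return transpose(self, *dims)

    @property
    def T(self):
        # Like xarray's DataArray.T: transpose with no arguments.
        return self.transpose()


def transpose(x, *dims):
    # Simplified helper (assumption): no dims means "reverse all dimensions".
    new_dims = tuple(dims) if dims else x.dims[::-1]
    if len(new_dims) != len(x.dims) or set(new_dims) != set(x.dims):
        raise ValueError(f"{new_dims} is not a permutation of {x.dims}")
    return LabeledVar(new_dims)


x = LabeledVar(("a", "b", "c"))
reversed_dims = x.T.dims
swapped_dims = x.transpose("b", "a", "c").dims
```

Keeping the logic in the helper and making the method and property thin wrappers means both entry points stay consistent.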

# Determine the permutation of axes
out_dims = node.op.dims
in_dims = x.type.dims
expanded_dims = expand_ellipsis(out_dims, in_dims)
Member

You don't need to call expand_ellipsis again; you have the ground truth in node.outputs[0].type.dims, which you already computed in make_node.
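
As a rough illustration of this point: once the output dims are already resolved, the axis permutation follows from a simple lookup. The function name `axis_permutation` is hypothetical, and plain tuples stand in for `x.type.dims` and `node.outputs[0].type.dims`.

```python
def axis_permutation(in_dims, out_dims):
    """Return, for each output dim, the index of that dim in the input."""
    return tuple(in_dims.index(d) for d in out_dims)


in_dims = ("a", "b", "c")   # stands in for x.type.dims
out_dims = ("c", "a", "b")  # stands in for node.outputs[0].type.dims
perm = axis_permutation(in_dims, out_dims)
```

Because `out_dims` no longer contains an ellipsis at this stage, no second expansion step is needed.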

class Transpose(XOp):
    __props__ = ("dims",)

    def __init__(self, dims: tuple[str, ...]):
Member

I guess the type hint is wrong, because dims can include ellipsis

Suggested change
- def __init__(self, dims: tuple[str, ...]):
+ def __init__(self, dims: tuple[str | Ellipsis, ...]):



def transpose(x, *dims):
    return Transpose(dims)(x)
Member

@ricardoV94 ricardoV94 May 27, 2025

xarray has a missing_dims argument that we could provide here as well: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.transpose.html

We already do that for isel:

def isel(
    self,
    indexers: dict[str, Any] | None = None,
    drop: bool = False,  # Unused by PyTensor
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
    **indexers_kwargs,
):
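
A possible shape for the same option in transpose, as a sketch only. The function name `filter_missing_dims` and the error messages are hypothetical; the three modes mirror the `missing_dims` values in the isel signature above.

```python
import warnings
from typing import Literal


def filter_missing_dims(
    dims: tuple,
    in_dims: tuple[str, ...],
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
) -> tuple:
    """Drop requested dims that are absent from in_dims, per missing_dims."""
    if missing_dims not in ("raise", "warn", "ignore"):
        raise ValueError(f"Unrecognised option {missing_dims!r} for missing_dims")
    # The ellipsis is a placeholder, never a missing dimension.
    missing = [d for d in dims if d is not Ellipsis and d not in in_dims]
    if missing:
        msg = f"Dimensions {missing} do not exist. Expected one or more of {in_dims}"
        if missing_dims == "raise":
            raise ValueError(msg)
        if missing_dims == "warn":
            warnings.warn(msg)
    return tuple(d for d in dims if d is Ellipsis or d in in_dims)


kept = filter_missing_dims(("a", "z", ...), ("a", "b"), missing_dims="ignore")
```

The filtered tuple can then go through the usual ellipsis expansion and permutation checks.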

@@ -73,6 +73,67 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y


def expand_ellipsis(
Member

Should this check that there's at most one ellipsis in dims and raise otherwise?
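
A sketch of ellipsis expansion that includes such a check. This is not the PR's implementation: the name `expand_ellipsis_sketch`, its argument order, and the error message are assumptions for illustration.

```python
def expand_ellipsis_sketch(dims, all_dims):
    """Replace a single ... in dims with the dims of all_dims not named explicitly."""
    n_ellipsis = sum(d is Ellipsis for d in dims)
    if n_ellipsis > 1:
        raise ValueError("dims can contain at most one ellipsis (...)")
    if n_ellipsis == 0:
        return tuple(dims)
    pos = dims.index(Ellipsis)
    explicit = [d for d in dims if d is not Ellipsis]
    # Dims not named explicitly keep their original relative order.
    rest = tuple(d for d in all_dims if d not in explicit)
    return tuple(dims[:pos]) + rest + tuple(dims[pos + 1:])


front = expand_ellipsis_sketch(("c", ...), ("a", "b", "c"))
back = expand_ellipsis_sketch((..., "a"), ("a", "b", "c"))
```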

@AllenDowney
Author

@ricardoV94 I think I've addressed all of your comments -- please take another look.

Member

@ricardoV94 ricardoV94 left a comment

Looks pretty good. Just some issue with the type hint, and because of that I took the liberty of leaving some other nitpicky comments. Feel free to not address them if you disagree / it's too cumbersome

@AllenDowney
Author

This PR adds support for expanding dimensions in labeled tensors, similar to xarray's functionality. The key changes include:

  1. Added ExpandDims operation that allows adding new dimensions to a tensor while preserving the existing labeled dimensions
  2. Implemented lower_expand_dims rewrite that efficiently handles the dimension expansion using reshape operations
  3. Added corresponding tests to verify the behavior matches xarray's expand_dims functionality
  4. Ensured code style compliance and improved documentation

The implementation allows users to add new dimensions to their labeled tensors while maintaining the semantic meaning of the existing dimensions, which is particularly useful for broadcasting and reshaping operations in data analysis workflows.

@AllenDowney
Author

@ricardoV94 I've added an implementation of Squeeze. Please take a look.


# Create new shape with the new dimension
new_shape = list(x.type.shape)
new_shape.append(1)
Member

I think this is wrong. New dims are added on the left by default, and xarray also allows specifying a non-default axis: https://docs.xarray.dev/en/stable/generated/xarray.DataArray.expand_dims.html
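
To make the positional behaviour concrete, here is a small sketch that tracks only dimension names. The function `expand_dims_names` is hypothetical, and the handling of negative axes (counting from the end of the result, as `numpy.expand_dims` does) is an assumption.

```python
def expand_dims_names(dims, new_dim, axis=None):
    """Return the dim names after inserting new_dim; axis=None inserts at the front."""
    if new_dim in dims:
        raise ValueError(f"Dimension {new_dim} already exists")
    position = 0 if axis is None else axis
    if position < 0:
        # Assumption: negative axes index into the result, as in numpy.expand_dims.
        position += len(dims) + 1
    out = list(dims)
    out.insert(position, new_dim)
    return tuple(out)


default_dims = expand_dims_names(("a", "b"), "c")
middle_dims = expand_dims_names(("a", "b"), "c", axis=1)
last_dims = expand_dims_names(("a", "b"), "c", axis=-1)
```

Appending to the shape, as in the snippet under review, corresponds to the `axis=-1` case rather than the default.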

if self.dim not in x.type.dims:
    raise ValueError(f"Dimension {self.dim} not found")
dim_idx = x.type.dims.index(self.dim)
if x.type.shape[dim_idx] != 1:
Member

The entry in x.type.shape may also be unknown (None), in which case the length can be 1 or something else. It's fine to let it through in that case (it will fail at runtime if it is not 1). Only if the entry is neither None nor 1 do we know it must be invalid.
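
A sketch of that check on plain tuples. The function name `check_squeezable` and the error messages are hypothetical; `None` marks a length that is not known statically.

```python
def check_squeezable(dims, static_shape, dim):
    """Return the axis of dim, rejecting only lengths statically known to be != 1."""
    if dim not in dims:
        raise ValueError(f"Dimension {dim} not found")
    axis = dims.index(dim)
    length = static_shape[axis]
    # None means "unknown until runtime", so it is allowed through here.
    if length is not None and length != 1:
        raise ValueError(f"Dimension {dim} has static length {length}, expected 1")
    return axis


unknown_axis = check_squeezable(("a", "b"), (None, 3), "a")
known_axis = check_squeezable(("a", "b"), (3, 1), "b")
```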

xr_assert_allclose(fn(x_test), x_test.transpose())


def test_expand_dims():
Member

If you had kept the template of comparing against xarray like the other tests, you would have noticed expand_dims is not doing the same as xarray

Author

Yes, comparing against xarray makes perfect sense. I will do that in the next iteration.

xr_assert_allclose(res_multi, expected_res_multi)


def test_lower_squeeze():
Member

This one is nonsensical

Author

Cursor and I put that in while debugging -- and it did help. But if it doesn't make sense to keep it as a test, I'll remove it.

@ricardoV94
Member

ricardoV94 commented May 29, 2025

@AllenDowney I picked and cleaned up the transpose-related changes into #1430.

Regarding squeeze and expand_dims, I suggest you tackle them in a separate PR. Two notes:

  1. Use pytensor.tensor.squeeze / expand_dims instead of reshape. The reason is that reshape is a very general Op that is harder to reason about symbolically. We have many rewrites for squeeze/expand_dims that don't apply to reshape, because it's not always easy to know when the two are equivalent.
  2. Note that xarray expand_dims allows specifying the size of the new dimension; it need not be length 1. You can do DataArray(np.zeros((2, 3)), dims=("a", "b")).expand_dims(c=5). This makes sense because xarray does not broadcast existing dims, so they should have the "right" sizes when they are created. This will require a mix of expand_dims and broadcast_to when lowered. It also means expand_dims will be a bit more complex, to allow the size of the new dimension to be symbolic, and will be a bit more like what we are exploring in #1412 (Add unstack for xtensors). Feel free to split Squeeze from ExpandDims, since the former is simpler.
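
The two-step lowering in the second note can be sketched as shape bookkeeping on plain tuples. The function `plan_expand_dims` is hypothetical, and it assumes concrete integer sizes; symbolic sizes, as discussed above, are not modelled here.

```python
def plan_expand_dims(dims, shape, **new_sizes):
    """Plan expand_dims with explicit sizes as two steps.

    Returns (out_dims, unit_shape, target_shape): unit_shape is the shape after
    inserting length-1 axes on the left, target_shape the shape after
    broadcasting those axes to the requested sizes.
    """
    for name in new_sizes:
        if name in dims:
            raise ValueError(f"Dimension {name} already exists")
    new_dims = tuple(new_sizes)
    out_dims = new_dims + tuple(dims)
    unit_shape = (1,) * len(new_dims) + tuple(shape)
    target_shape = tuple(new_sizes.values()) + tuple(shape)
    return out_dims, unit_shape, target_shape


# Mirrors the expand_dims(c=5) example on dims ("a", "b") with shape (2, 3).
out_dims, unit_shape, target_shape = plan_expand_dims(("a", "b"), (2, 3), c=5)
needs_broadcast = unit_shape != target_shape
```

When every requested size is 1, the two shapes coincide and the broadcast step can be skipped.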

@ricardoV94 ricardoV94 closed this Jun 4, 2025