Skip to content

Commit 7da9935

Browse files
AllenDowney authored and ricardoV94 committed
Implement squeeze for XTensorVariables
1 parent c263d7f commit 7da9935

File tree

4 files changed

+235
-3
lines changed

4 files changed

+235
-3
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
from pytensor.graph import node_rewriter
2-
from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
2+
from pytensor.tensor import (
3+
broadcast_to,
4+
join,
5+
moveaxis,
6+
specify_shape,
7+
squeeze,
8+
)
39
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
410
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
5-
from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
11+
from pytensor.xtensor.shape import (
12+
Concat,
13+
Squeeze,
14+
Stack,
15+
Transpose,
16+
UnStack,
17+
)
618

719

820
@register_lower_xtensor
@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
105117
x_tensor_transposed = x_tensor.transpose(perm)
106118
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
107119
return [new_out]
120+
121+
122+
@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
    """Lower an xtensor ``Squeeze`` node to a plain tensor ``squeeze``.

    The dimension names stored on the op are translated to positional axes of
    the input, the unwrapped tensor is squeezed along those axes, and the
    result is wrapped again using the dims of the node's existing output.
    """
    (xvar,) = node.inputs
    raw = tensor_from_xtensor(xvar)

    # Translate each named dimension into its position in the input.
    in_dims = xvar.type.dims
    squeeze_axes = tuple(in_dims.index(name) for name in node.op.dims)
    raw_squeezed = squeeze(raw, axis=squeeze_axes)

    out_dims = node.outputs[0].type.dims
    return [xtensor_from_tensor(raw_squeezed, dims=out_dims)]

pytensor/xtensor/shape.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,82 @@ def make_node(self, *inputs):
301301

302302
def concat(xtensors, dim: str):
303303
return Concat(dim=dim)(*xtensors)
304+
305+
306+
class Squeeze(XOp):
    """Remove specified dimensions from an XTensorVariable.

    Every requested dimension must exist in the input. A dimension whose
    static size is known must have size 1; a dimension whose static size is
    unknown (``None``) is accepted and assumed to be safe to remove.

    Parameters
    ----------
    dims : iterable of str
        The names of the dimensions to remove. Duplicates are dropped and the
        names are stored sorted, so the stored value does not depend on the
        order or multiplicity in which the names were given.
    """

    __props__ = ("dims",)

    def __init__(self, dims):
        # Normalize to a sorted, duplicate-free tuple.
        self.dims = tuple(sorted(set(dims)))

    def make_node(self, x):
        """Build the Apply node that drops ``self.dims`` from ``x``.

        Raises
        ------
        ValueError
            If a requested dimension is not present in ``x``, or if its
            static size is known and different from 1.
        """
        x = as_xtensor(x)
        x_dims = x.type.dims
        x_shape = x.type.shape

        # Validate that dims exist and are size-1 if statically known,
        # collecting the positional axes that will be dropped.
        axes_to_remove = set()
        for d in self.dims:
            if d not in x_dims:
                raise ValueError(f"Dimension {d} not found in {x.type.dims}")
            idx = x_dims.index(d)
            dim_size = x_shape[idx]
            if dim_size is not None and dim_size != 1:
                raise ValueError(f"Dimension {d} has static size {dim_size}, not 1")
            axes_to_remove.add(idx)

        new_dims = tuple(d for i, d in enumerate(x_dims) if i not in axes_to_remove)
        new_shape = tuple(s for i, s in enumerate(x_shape) if i not in axes_to_remove)

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x], [out])
352+
353+
354+
def squeeze(x, dim=None):
    """Drop size-1 dimensions from an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        The input tensor.
    dim : str or None or iterable of str, optional
        Name(s) of the dimension(s) to drop. When None, every dimension whose
        static size is 1 is dropped; dimensions with unknown static size are
        kept, even if they turn out to have size 1 at runtime.

    Returns
    -------
    XTensorVariable
        A tensor without the selected dimension(s). When there is nothing to
        drop, the input (converted with ``as_xtensor``) is returned as is.
    """
    x = as_xtensor(x)

    if isinstance(dim, str):
        # A single name is treated as a one-element selection.
        names = (dim,)
    elif dim is not None:
        names = tuple(dim)
    else:
        # Implicit selection: only dimensions statically known to be size 1.
        names = tuple(
            name for name, size in zip(x.type.dims, x.type.shape) if size == 1
        )

    if names:
        return Squeeze(dims=names)(x)
    return x  # no-op if nothing to squeeze

pytensor/xtensor/type.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,16 @@ def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
471471
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
472472
return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
473473

474+
def squeeze(
475+
self,
476+
dim: Sequence[str] | str | None = None,
477+
drop: bool = False,
478+
axis: int | Sequence[int] | None = None,
479+
):
480+
if axis is not None:
481+
raise NotImplementedError("Squeeze with axis not Implemented")
482+
return px.shape.squeeze(self, dim)
483+
474484
# ndarray methods
475485
# https://docs.xarray.dev/en/latest/api.html#id7
476486
def clip(self, min, max):

tests/xtensor/test_shape.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11+
import pytest
1112
from xarray import DataArray
1213
from xarray import concat as xr_concat
1314

14-
from pytensor.xtensor.shape import concat, stack, transpose, unstack
15+
from pytensor.xtensor.shape import (
16+
concat,
17+
squeeze,
18+
stack,
19+
transpose,
20+
unstack,
21+
)
1522
from pytensor.xtensor.type import xtensor
1623
from tests.xtensor.util import (
1724
xr_arange_like,
@@ -21,6 +28,9 @@
2128
)
2229

2330

31+
pytest.importorskip("xarray")
32+
33+
2434
def powerset(iterable, min_group_size=0):
2535
"Subsequences of the iterable from shortest to longest."
2636
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
@@ -253,3 +263,109 @@ def test_concat_scalar():
253263
res = fn(x1_test, x2_test)
254264
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
255265
xr_assert_allclose(res, expected_res)
266+
267+
268+
def test_squeeze_explicit_dims():
    """Check squeeze when the dimension(s) to drop are named explicitly."""

    # One named dimension
    single = xtensor("x1", dims=("city", "country"), shape=(3, 1))
    single_fn = xr_function([single], squeeze(single, "country"))
    single_val = xr_arange_like(single)
    xr_assert_allclose(single_fn(single_val), single_val.squeeze("country"))

    # Several named dimensions at once
    multi = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
    multi_fn = xr_function([multi], squeeze(multi, ["b", "c"]))
    multi_val = xr_arange_like(multi)
    xr_assert_allclose(multi_fn(multi_val), multi_val.squeeze(["b", "c"]))

    # The order in which names are listed must not matter
    base = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
    forward_fn = xr_function([base], squeeze(base, ["b", "c"]))
    reverse_fn = xr_function([base], squeeze(base, ["c", "b"]))
    base_val = xr_arange_like(base)
    xr_assert_allclose(forward_fn(base_val), reverse_fn(base_val))

    # A name listed twice behaves like listing it once
    repeated_fn = xr_function([base], squeeze(base, ["b", "b"]))
    xr_assert_allclose(repeated_fn(base_val), base_val.squeeze(["b", "b"]))

    # An empty selection leaves the input untouched
    empty_fn = xr_function([base], squeeze(base, []))
    xr_assert_allclose(empty_fn(base_val), base_val)
303+
304+
305+
def test_squeeze_implicit_dims():
    """Test squeeze with implicit dim=None, plus explicit dims on symbolic shapes.

    The first two cases and the last one call ``squeeze`` without ``dim``;
    the third and fourth name the dimension explicitly on inputs whose static
    shape is fully or partly unknown.
    """

    # All dimensions size 1
    x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1))
    y1 = squeeze(x1)
    fn1 = xr_function([x1], y1)
    x1_test = xr_arange_like(x1)
    xr_assert_allclose(fn1(x1_test), x1_test.squeeze())

    # No dimensions size 1 = no-op
    x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4))
    y2 = squeeze(x2)
    fn2 = xr_function([x2], y2)
    x2_test = xr_arange_like(x2)
    xr_assert_allclose(fn2(x2_test), x2_test)

    # Symbolic shape, dim named explicitly, runtime size is 1 → should squeeze
    x3 = xtensor("x3", dims=("a", "b", "c"))  # shape unknown
    y3 = squeeze(x3, "b")
    x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
    fn3 = xr_function([x3], y3)
    xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))

    # Mixed static + symbolic shapes: the squeezed dim "b" is statically 1,
    # while the symbolic dim "a" is left in place (size 4 at runtime)
    x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
    y4 = squeeze(x4, "b")
    x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
    fn4 = xr_function([x4], y4)
    xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))

    # This case documents that we intentionally don't squeeze dimensions with
    # symbolic shapes (static_shape=None) even when they are 1 at runtime,
    # while xarray does squeeze them.

    # Create a tensor with a symbolic dimension that will be 1 at runtime
    x = xtensor("x", dims=("a", "b", "c"))  # shape unknown
    y = squeeze(x)  # implicit dim=None should not squeeze symbolic dimensions
    x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3)))
    fn = xr_function([x], y)
    res = fn(x_test)

    # Our implementation should not squeeze the symbolic dimension
    assert "b" in res.dims
    # While xarray would squeeze it
    assert "b" not in x_test.squeeze().dims
351+
352+
353+
def test_squeeze_errors():
    """Check the failure modes of squeeze."""

    static_in = xtensor("x1", dims=("city", "country"), shape=(3, 1))

    # Asking for a dimension the input does not have
    with pytest.raises(ValueError, match="Dimension .* not found"):
        squeeze(static_in, "time")

    # Asking for a dimension whose static size is larger than 1
    with pytest.raises(ValueError, match="has static size .* not 1"):
        squeeze(static_in, "city")

    # Unknown static shape: the graph builds, but evaluating it with a
    # runtime size other than 1 for the squeezed dim should raise
    symbolic_in = xtensor("x2", dims=("a", "b", "c"))  # shape unknown
    squeezed = squeeze(symbolic_in, "b")
    bad_value = xr_arange_like(xtensor(dims=symbolic_in.dims, shape=(2, 2, 3)))
    compiled = xr_function([symbolic_in], squeezed)
    with pytest.raises(Exception):
        compiled(bad_value)

0 commit comments

Comments
 (0)