Skip to content

Commit d50db11

Browse files
committed
Provide static shape in output of Split
1 parent 5308ddd commit d50db11

File tree

2 files changed: 27 additions and 5 deletions.

pytensor/tensor/basic.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,8 +2201,28 @@ def make_node(self, x, axis, splits):
22012201
raise TypeError("`axis` parameter must be an integer scalar")
22022202

22032203
inputs = [x, axis, splits]
2204-
out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
2205-
outputs = [out_type() for i in range(self.len_splits)]
2204+
2205+
x_dtype = x.type.dtype
2206+
if isinstance(axis, Constant):
2207+
# In this case we can preserve more static shape info
2208+
static_axis = axis.data.item()
2209+
outputs = []
2210+
x_static_shape = list(x.type.shape)
2211+
for i in range(self.len_splits):
2212+
try:
2213+
static_split_size = int(get_scalar_constant_value(splits[i]))
2214+
except NotScalarConstantError:
2215+
static_split_size = None
2216+
except IndexError:
2217+
raise ValueError("Number of splits is larger than splits size")
2218+
static_out_shape = x_static_shape.copy()
2219+
static_out_shape[static_axis] = static_split_size
2220+
outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
2221+
else:
2222+
outputs = [
2223+
tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
2224+
for i in range(self.len_splits)
2225+
]
22062226

22072227
return Apply(self, inputs, outputs)
22082228

tests/link/jax/test_tensor_basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,14 @@ def test_runtime_errors(self):
150150
):
151151
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
152152

153-
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=3, axis=0)
154-
fn = pytensor.function([a], a_splits, mode="JAX")
153+
# This check is triggered at compile time if splits_size has incompatible static length
154+
splits_size = vector("splits_size", shape=(None,), dtype=int)
155+
a_splits = ptb.split(a, splits_size=splits_size, n_splits=3, axis=0)
156+
fn = pytensor.function([a, splits_size], a_splits, mode="JAX")
155157
with pytest.raises(
156158
ValueError, match="Length of splits is not equal to n_splits"
157159
):
158-
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
160+
fn(np.zeros((6, 4), dtype=pytensor.config.floatX), [2, 2])
159161

160162
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=0)
161163
fn = pytensor.function([a], a_splits, mode="JAX")

0 commit comments

Comments
 (0)