Skip to content

Commit 1701ed7

Browse files
authored
(Re)implement ndonnx.repeat for all built-in data types (#161)
1 parent 53d8c89 commit 1701ed7

File tree

6 files changed

+109
-9
lines changed

6 files changed

+109
-9
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
Changelog
88
=========
99

10-
0.14.1 (unreleased)
10+
0.15.0 (2025-08-13)
1111
-------------------
1212

13+
**New feature**
14+
15+
- :func:`ndonnx.repeat` is now implemented for all built-in data types.
16+
1317
**Other changes**
1418

1519
- Improve stacklevel presented of various deprecation warnings.

ndonnx/_typed_array/onnx.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -840,11 +840,6 @@ def reshape(self, shape: tuple[int, ...] | TyArrayInt64) -> Self:
840840
var = op.reshape(self._var, shape._var, allowzero=True)
841841
return type(self)(var)
842842

843-
def repeat(
844-
self, repeats: int | TyArrayInt64, /, *, axis: int | None = None
845-
) -> Self:
846-
raise NotImplementedError
847-
848843
def searchsorted(
849844
self,
850845
x2: Self,
@@ -2820,3 +2815,18 @@ def dummy_n_m(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
28202815

28212816
arr = type(a)(var_with_propagation)
28222817
return arr
2818+
2819+
2820+
def arange(
2821+
start: TyArrayInt64,
2822+
/,
2823+
stop: TyArrayInt64 | None = None,
2824+
step: int | TyArrayInt64 = 1,
2825+
) -> TyArrayInt64:
2826+
if stop is None:
2827+
stop = start
2828+
start = const(0)
2829+
if isinstance(step, int):
2830+
step = const(step)
2831+
var = op.range(start=start._var, limit=stop._var, delta=step._var)
2832+
return TyArrayInt64(var)

ndonnx/_typed_array/typed_array.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,58 @@ def variance(
457457
def repeat(
458458
self, repeats: int | TyArrayInt64, /, *, axis: int | None = None
459459
) -> Self:
460-
raise _make_type_error("repeat", self.dtype)
460+
# Default implementation of `repeat` that does not depend on the underlying layout
461+
from . import onnx
462+
463+
x = self
464+
if axis is None:
465+
x = x.reshape((-1,))
466+
axis = 0
467+
468+
if axis < 0:
469+
axis += x.ndim
470+
471+
x_shape = x.dynamic_shape
472+
473+
if isinstance(repeats, onnx.TyArrayInt64) and repeats.ndim > 1:
474+
raise ValueError(
475+
f"'repeats' must be 0 or 1 dimensional, but has `{repeats.ndim}` dimensions"
476+
)
477+
478+
if isinstance(repeats, int):
479+
repeats = onnx.const(np.asarray(repeats), onnx.int64)
480+
481+
if repeats.ndim == 0:
482+
repeats = repeats[None]
483+
# Expand by one dimension and use regular broadcasting to repeat
484+
key: list[slice | None] = [slice(None, None)] * (x.ndim + 1)
485+
key[axis + 1] = None
486+
tmp = x[tuple(key)].broadcast_to(
487+
x_shape[: axis + 1].concat([repeats, x_shape[axis + 1 :]], axis=0)
488+
)
489+
# We now have a shape of (..., self.shape[axis], repeats, ...).
490+
# All we need to do is collapse these two axes into one.
491+
out_shape = x_shape.copy()
492+
out_shape[axis] = safe_cast(onnx.TyArrayInt64, out_shape[axis] * repeats)
493+
494+
return tmp.reshape(out_shape)
495+
496+
# Repeats may be of shape (1,)
497+
repeats = repeats.broadcast_to(x_shape[axis][None])
498+
499+
max_rep = repeats.max(keepdims=False)
500+
repeated_with_maxrep = x.repeat(max_rep, axis=axis)
501+
502+
mask2D = onnx.arange(max_rep)[None, :] < repeats[:, None]
503+
mask1D = mask2D.reshape((-1,))
504+
505+
# We can only use __getitem__ with a boolean mask on 1D arrays.
506+
# Otherwise, we have to compute indices and use `take`
507+
if x.ndim == 1:
508+
return repeated_with_maxrep[mask1D]
509+
510+
indices = onnx.arange(mask1D.dynamic_size)[mask1D]
511+
return repeated_with_maxrep.take(indices, axis=axis)
461512

462513
def tile(self, repetitions: tuple[int, ...], /) -> Self:
463514
raise _make_type_error("tile", self.dtype)

skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Missing implementations which should be perfectly doable
22
array_api_tests/test_linalg.py::test_tensordot
3-
array_api_tests/test_manipulation_functions.py::test_repeat
43

54
# Tests that fail due to lossy ORT workarounds, which are tested elsewhere
65
array_api_tests/test_statistical_functions.py::test_sum

tests/test_manipulation_functions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,39 @@ def do(npx):
1515
return npx.concat([a1, a2], axis=None)
1616

1717
np.testing.assert_array_equal(do(np), do(ndx))
18+
19+
20+
@pytest.mark.parametrize(
21+
"x, repeats, axis",
22+
[
23+
([1, 2, 3], 2, None),
24+
([1, 2, 3], np.asarray(2), None),
25+
([1, 2, 3], np.asarray(2), 0),
26+
([1, 2, 3], np.asarray([2]), None),
27+
([1, 2, 3], [1, 0, 2], None),
28+
([1, 2, 3], [1, 0, 2], 0),
29+
([[1, 2, 3]], [1, 0, 2], None),
30+
([[1, 2, 3]], [1, 0, 2], 1),
31+
(np.arange(27).reshape((3, 3, 3)), [1, 0, 2], 0),
32+
(np.arange(27).reshape((3, 3, 3)), [1, 0, 2], 1),
33+
(np.arange(27).reshape((3, 3, 3)), [1, 0, 2], 2),
34+
# zero-size
35+
(np.ones((3, 0, 3)), 2, None),
36+
(np.ones((3, 0, 3)), 2, 0),
37+
(np.ones((3, 0, 3)), 2, 1),
38+
(np.ones((3, 0, 3)), 2, 2),
39+
(np.ones((3, 0, 3)), [2], 2),
40+
(np.ones((3, 0, 3)), [1, 0, 2], 2),
41+
# other data types
42+
(["a", "b", "c"], np.asarray([2]), None),
43+
(np.asarray([1, 2, 3], dtype="datetime64[s]"), np.asarray([2]), None),
44+
],
45+
)
46+
def test_repeat(x, repeats, axis):
47+
def do(npx):
48+
repeats_ = repeats
49+
if isinstance(repeats, list | np.ndarray):
50+
repeats_ = npx.asarray(repeats_)
51+
return npx.repeat(npx.asarray(x), repeats_, axis=axis)
52+
53+
np.testing.assert_array_equal(do(ndx).unwrap_numpy(), do(np))

0 commit comments

Comments
 (0)