Skip to content

Commit 4425d14

Browse files
authored
ENH: allow_dask_compute=True instead of 999 (#283)
1 parent 9376399 commit 4425d14

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

src/array_api_extra/testing.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def override(func: object) -> object:
3939
def lazy_xp_function( # type: ignore[explicit-any]
4040
func: Callable[..., Any],
4141
*,
42-
allow_dask_compute: int = 0,
42+
allow_dask_compute: bool | int = False,
4343
jax_jit: bool = True,
4444
static_argnums: int | Sequence[int] | None = None,
4545
static_argnames: str | Iterable[str] | None = None,
@@ -59,9 +59,10 @@ def lazy_xp_function( # type: ignore[explicit-any]
5959
----------
6060
func : callable
6161
Function to be tested.
62-
allow_dask_compute : int, optional
63-
Number of times `func` is allowed to internally materialize the Dask graph. This
64-
is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``.
62+
allow_dask_compute : bool | int, optional
63+
Whether `func` is allowed to internally materialize the Dask graph, or maximum
64+
number of times it is allowed to do so. This is typically triggered by
65+
``bool()``, ``float()``, or ``np.asarray()``.
6566
6667
Set to 1 if you are aware that `func` converts the input parameters to NumPy and
6768
want to let it do so at least for the time being, knowing that it is going to be
@@ -75,7 +76,10 @@ def lazy_xp_function( # type: ignore[explicit-any]
7576
a test function that invokes `func` multiple times should still work with this
7677
parameter set to 1.
7778
78-
Default: 0, meaning that `func` must be fully lazy and never materialize the
79+
Set to True to allow `func` to materialize the graph an unlimited number
80+
of times.
81+
82+
Default: False, meaning that `func` must be fully lazy and never materialize the
7983
graph.
8084
jax_jit : bool, optional
8185
Set to True to replace `func` with ``jax.jit(func)`` after calling the
@@ -235,6 +239,10 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
235239
if is_dask_namespace(xp):
236240
for mod, name, func, tags in iter_tagged():
237241
n = tags["allow_dask_compute"]
242+
if n is True:
243+
n = 1_000_000
244+
elif n is False:
245+
n = 0
238246
wrapped = _dask_wrap(func, n)
239247
monkeypatch.setattr(mod, name, wrapped)
240248

tests/test_testing.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,18 @@ def non_materializable4(x: Array) -> Array:
130130
return non_materializable(x)
131131

132132

133+
def non_materializable5(x: Array) -> Array:
134+
return non_materializable(x)
135+
136+
133137
lazy_xp_function(good_lazy)
134138
# Works on JAX and Dask
135139
lazy_xp_function(non_materializable2, jax_jit=False, allow_dask_compute=2)
140+
lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=True)
136141
# Works on JAX, but not Dask
137-
lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=1)
142+
lazy_xp_function(non_materializable4, jax_jit=False, allow_dask_compute=1)
138143
# Works neither on Dask nor JAX
139-
lazy_xp_function(non_materializable4)
144+
lazy_xp_function(non_materializable5)
140145

141146

142147
def test_lazy_xp_function(xp: ModuleType):
@@ -147,29 +152,30 @@ def test_lazy_xp_function(xp: ModuleType):
147152
xp_assert_equal(non_materializable(x), xp.asarray([1.0, 2.0]))
148153
# Wrapping explicitly disabled
149154
xp_assert_equal(non_materializable2(x), xp.asarray([1.0, 2.0]))
155+
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
150156

151157
if is_jax_namespace(xp):
152-
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
158+
xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0]))
153159
with pytest.raises(
154160
TypeError, match="Attempted boolean conversion of traced array"
155161
):
156-
_ = non_materializable4(x) # Wrapped
162+
_ = non_materializable5(x) # Wrapped
157163

158164
elif is_dask_namespace(xp):
159165
with pytest.raises(
160166
AssertionError,
161167
match=r"dask\.compute.* 2 times, but only up to 1 calls are allowed",
162168
):
163-
_ = non_materializable3(x)
169+
_ = non_materializable4(x)
164170
with pytest.raises(
165171
AssertionError,
166172
match=r"dask\.compute.* 1 times, but no calls are allowed",
167173
):
168-
_ = non_materializable4(x)
174+
_ = non_materializable5(x)
169175

170176
else:
171-
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
172177
xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0]))
178+
xp_assert_equal(non_materializable5(x), xp.asarray([1.0, 2.0]))
173179

174180

175181
def static_params(x: Array, n: int, flag: bool = False) -> Array:

0 commit comments

Comments
 (0)