@@ -130,13 +130,18 @@ def non_materializable4(x: Array) -> Array:
130
130
return non_materializable (x )
131
131
132
132
133
+ def non_materializable5 (x : Array ) -> Array :
134
+ return non_materializable (x )
135
+
136
+
133
137
lazy_xp_function (good_lazy )
134
138
# Works on JAX and Dask
135
139
lazy_xp_function (non_materializable2 , jax_jit = False , allow_dask_compute = 2 )
140
+ lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = True )
136
141
# Works on JAX, but not Dask
137
- lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = 1 )
142
+ lazy_xp_function (non_materializable4 , jax_jit = False , allow_dask_compute = 1 )
138
143
# Works neither on Dask nor JAX
139
- lazy_xp_function (non_materializable4 )
144
+ lazy_xp_function (non_materializable5 )
140
145
141
146
142
147
def test_lazy_xp_function (xp : ModuleType ):
@@ -147,29 +152,30 @@ def test_lazy_xp_function(xp: ModuleType):
147
152
xp_assert_equal (non_materializable (x ), xp .asarray ([1.0 , 2.0 ]))
148
153
# Wrapping explicitly disabled
149
154
xp_assert_equal (non_materializable2 (x ), xp .asarray ([1.0 , 2.0 ]))
155
+ xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
150
156
151
157
if is_jax_namespace (xp ):
152
- xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
158
+ xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
153
159
with pytest .raises (
154
160
TypeError , match = "Attempted boolean conversion of traced array"
155
161
):
156
- _ = non_materializable4 (x ) # Wrapped
162
+ _ = non_materializable5 (x ) # Wrapped
157
163
158
164
elif is_dask_namespace (xp ):
159
165
with pytest .raises (
160
166
AssertionError ,
161
167
match = r"dask\.compute.* 2 times, but only up to 1 calls are allowed" ,
162
168
):
163
- _ = non_materializable3 (x )
169
+ _ = non_materializable4 (x )
164
170
with pytest .raises (
165
171
AssertionError ,
166
172
match = r"dask\.compute.* 1 times, but no calls are allowed" ,
167
173
):
168
- _ = non_materializable4 (x )
174
+ _ = non_materializable5 (x )
169
175
170
176
else :
171
- xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
172
177
xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
178
+ xp_assert_equal (non_materializable5 (x ), xp .asarray ([1.0 , 2.0 ]))
173
179
174
180
175
181
def static_params (x : Array , n : int , flag : bool = False ) -> Array :
0 commit comments