@@ -178,31 +178,53 @@ def test_groupby_reduce(
178
178
assert_equal (expected_result , result )
179
179
180
180
181
- def gen_array_by (size , func ):
182
- by = np .ones (size [- 1 ])
183
- rng = np .random .default_rng (12345 )
181
+ def maybe_skip_cupy (array_module , func , engine ):
182
+ if array_module is np :
183
+ return
184
+
185
+ import cupy
186
+
187
+ assert array_module is cupy
188
+
189
+ if engine == "numba" :
190
+ pytest .skip ()
191
+
192
+ if engine == "numpy" and ("prod" in func or "first" in func or "last" in func ):
193
+ pytest .xfail ()
194
+ elif engine == "flox" and not (
195
+ "sum" in func or "mean" in func or "std" in func or "var" in func
196
+ ):
197
+ pytest .xfail ()
198
+
199
+
200
+ def gen_array_by (size , func , array_module ):
201
+ xp = array_module
202
+ by = xp .ones (size [- 1 ])
203
+ rng = xp .random .default_rng (12345 )
184
204
array = rng .random (size )
185
205
if "nan" in func and "nanarg" not in func :
186
- array [[1 , 4 , 5 ], ...] = np .nan
206
+ array [[1 , 4 , 5 ], ...] = xp .nan
187
207
elif "nanarg" in func and len (size ) > 1 :
188
- array [[1 , 4 , 5 ], 1 ] = np .nan
208
+ array [[1 , 4 , 5 ], 1 ] = xp .nan
189
209
if func in ["any" , "all" ]:
190
210
array = array > 0.5
191
211
return array , by
192
212
193
213
194
- @pytest .mark .parametrize ("chunks" , [None , - 1 , 3 , 4 ])
195
214
@pytest .mark .parametrize ("nby" , [1 , 2 , 3 ])
196
215
@pytest .mark .parametrize ("size" , ((12 ,), (12 , 9 )))
197
- @pytest .mark .parametrize ("add_nan_by " , [True , False ])
216
+ @pytest .mark .parametrize ("chunks " , [None , - 1 , 3 , 4 ])
198
217
@pytest .mark .parametrize ("func" , ALL_FUNCS )
199
- def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine ):
218
+ @pytest .mark .parametrize ("add_nan_by" , [True , False ])
219
+ def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine , array_module ):
200
220
if chunks is not None and not has_dask :
201
221
pytest .skip ()
202
222
if "arg" in func and engine == "flox" :
203
223
pytest .skip ()
204
224
205
- array , by = gen_array_by (size , func )
225
+ maybe_skip_cupy (array_module , func , engine )
226
+
227
+ array , by = gen_array_by (size , func , array_module )
206
228
if chunks :
207
229
array = dask .array .from_array (array , chunks = chunks )
208
230
by = (by ,) * nby
@@ -254,10 +276,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
254
276
assert expected .ndim == (array .ndim + nby - 1 )
255
277
expected_groups = tuple (np .array ([idx + 1.0 ]) for idx in range (nby ))
256
278
for actual_group , expect in zip (groups , expected_groups ):
257
- assert_equal (actual_group , expect )
279
+ assert_equal (actual_group , array_module . asarray ( expect ) )
258
280
if "arg" in func :
259
281
assert actual .dtype .kind == "i"
260
- assert_equal (actual , expected , tolerance )
282
+ if chunks is not None :
283
+ assert isinstance (actual ._meta , type (array ._meta ))
284
+ assert_equal (actual , array_module .asarray (expected ), tolerance )
261
285
262
286
if not has_dask or chunks is None :
263
287
continue
@@ -287,6 +311,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
287
311
assert_equal (actual_group , expect , tolerance )
288
312
if "arg" in func :
289
313
assert actual .dtype .kind == "i"
314
+ if chunks is not None :
315
+ assert isinstance (actual ._meta , type (array ._meta ))
290
316
assert_equal (actual , expected , tolerance )
291
317
292
318
@@ -313,18 +339,18 @@ def test_arg_reduction_dtype_is_int(size, func):
313
339
assert actual .dtype .kind == "i"
314
340
315
341
316
- def test_groupby_reduce_count ():
317
- array = np .array ([0 , 0 , np .nan , np .nan , np .nan , 1 , 1 ])
318
- labels = np .array (["a" , "b" , "b" , "b" , "c" , "c" , "c" ])
342
+ def test_groupby_reduce_count (array_module ):
343
+ array = array_module .array ([0 , 0 , np .nan , np .nan , np .nan , 1 , 1 ])
344
+ labels = array_module .array (["a" , "b" , "b" , "b" , "c" , "c" , "c" ])
319
345
result , _ = groupby_reduce (array , labels , func = "count" )
320
346
assert_equal (result , np .array ([1 , 1 , 2 ], dtype = np .intp ))
321
347
322
348
323
- def test_func_is_aggregation ():
349
+ def test_func_is_aggregation (array_module ):
324
350
from flox .aggregations import mean
325
351
326
- array = np .array ([0 , 0 , np .nan , np .nan , np .nan , 1 , 1 ])
327
- labels = np .array (["a" , "b" , "b" , "b" , "c" , "c" , "c" ])
352
+ array = array_module .array ([0 , 0 , np .nan , np .nan , np .nan , 1 , 1 ])
353
+ labels = array_module .array (["a" , "b" , "b" , "b" , "c" , "c" , "c" ])
328
354
expected , _ = groupby_reduce (array , labels , func = "mean" )
329
355
actual , _ = groupby_reduce (array , labels , func = mean )
330
356
assert_equal (actual , expected )
0 commit comments