Skip to content

Commit c2b83fd

Browse files
committed
Begin testing
1 parent 9d35bb2 commit c2b83fd

File tree

5 files changed

+87
-21
lines changed

5 files changed

+87
-21
lines changed

flox/aggregate_flox.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def _prepare_for_flox(group_idx, array):
1414
if issorted:
1515
ordered_array = array
1616
else:
17-
perm = group_idx.argsort(kind="stable")
17+
kind = "stable" if isinstance(group_idx, np.ndarray) else None
18+
19+
perm = np.argsort(group_idx, kind=kind)
1820
group_idx = group_idx[..., perm]
1921
ordered_array = array[..., perm]
2022
return group_idx, ordered_array

flox/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ def factorize_(
570570
else:
571571
assert sort
572572
groups, idx = np.unique(flat, return_inverse=True)
573+
idx[np.isnan(flat)] = -1
574+
groups = groups[~np.isnan(groups)]
573575

574576
found_groups.append(groups)
575577
factorized.append(idx.reshape(groupvar.shape))
@@ -1261,7 +1263,7 @@ def subset_to_blocks(
12611263
layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
12621264
graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
12631265

1264-
return dask.array.Array(graph, name, chunks, meta=array)
1266+
return dask.array.Array(graph, name, chunks, meta=array._meta)
12651267

12661268

12671269
def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
@@ -1494,6 +1496,7 @@ def dask_groupby_agg(
14941496
reduced,
14951497
inds,
14961498
adjust_chunks=dict(zip(out_inds, output_chunks)),
1499+
meta=array._meta,
14971500
dtype=agg.dtype["final"],
14981501
key=agg.name,
14991502
name=f"{name}-{token}",

tests/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
except ImportError:
2525
xr_types = () # type: ignore
2626

27+
try:
28+
import cupy as cp
29+
30+
cp_types = (cp.ndarray,)
31+
except ImportError:
32+
cp_types = () # type: ignore
33+
2734

2835
def _importorskip(modname, minversion=None):
2936
try:
@@ -80,6 +87,15 @@ def raise_if_dask_computes(max_computes=0):
8087
return dask.config.set(scheduler=scheduler)
8188

8289

90+
def to_numpy(a):
    """Return ``a`` as a host-side array for use with ``np.testing`` helpers.

    A dask array is computed first; if the (possibly computed) value is a
    cupy array it is then copied to host memory with ``.get()``. Any other
    input is returned unchanged.
    """
    # Materialise lazy dask collections before inspecting the concrete type.
    materialized = a.compute() if isinstance(a, dask_array_type) else a
    # ``cp_types`` is an empty tuple when cupy is not installed, so this
    # check is simply False in that case.
    return materialized.get() if isinstance(materialized, cp_types) else materialized
97+
98+
8399
def assert_equal(a, b, tolerance=None):
84100
__tracebackhide__ = True
85101

@@ -102,16 +118,20 @@ def assert_equal(a, b, tolerance=None):
102118
else:
103119
tolerance = {}
104120

105-
if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
121+
if has_dask and (isinstance(a, dask_array_type) or isinstance(b, dask_array_type)):
106122
# sometimes it's nice to see values and shapes
107123
# rather than being dropped into some file in dask
108-
np.testing.assert_allclose(a, b, **tolerance)
124+
np.testing.assert_allclose(to_numpy(a), to_numpy(b), **tolerance)
109125
# does some validation of the dask graph
110126
da.utils.assert_eq(a, b, equal_nan=True)
111127
else:
112128
if a.dtype != b.dtype:
113129
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
114130

131+
if isinstance(a, cp_types):
132+
a = a.get()
133+
if isinstance(b, cp_types):
134+
b = b.get()
115135
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
116136

117137

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,18 @@ def engine(request):
99
except ImportError:
1010
pytest.xfail()
1111
return request.param
12+
13+
14+
@pytest.fixture(scope="module", params=["numpy", "cupy"])
def array_module(request):
    """Module-scoped fixture yielding the array namespace under test.

    Parametrized over ``"numpy"`` and ``"cupy"``. When cupy cannot be
    imported the dependent test is marked as an expected failure.
    """
    backend = request.param
    if backend == "numpy":
        import numpy

        return numpy
    if backend == "cupy":
        try:
            import cupy  # noqa
        except ImportError:
            pytest.xfail()
        else:
            return cupy

tests/test_core.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -178,31 +178,53 @@ def test_groupby_reduce(
178178
assert_equal(expected_result, result)
179179

180180

181-
def gen_array_by(size, func):
182-
by = np.ones(size[-1])
183-
rng = np.random.default_rng(12345)
181+
def maybe_skip_cupy(array_module, func, engine):
    """Skip or xfail the current test for unsupported cupy combinations.

    Does nothing when ``array_module`` is numpy. Otherwise ``array_module``
    must be cupy, and:

    - the ``"numba"`` engine is skipped outright;
    - the ``"numpy"`` engine is xfailed when ``func`` contains ``"prod"``,
      ``"first"`` or ``"last"``;
    - the ``"flox"`` engine is xfailed unless ``func`` contains ``"sum"``,
      ``"mean"``, ``"std"`` or ``"var"``.
    """
    if array_module is np:
        return

    # Imported lazily so the numpy-only path never needs cupy installed.
    import cupy

    assert array_module is cupy

    if engine == "numba":
        pytest.skip()

    if engine == "numpy":
        expected_failure = any(token in func for token in ("prod", "first", "last"))
    elif engine == "flox":
        expected_failure = not any(
            token in func for token in ("sum", "mean", "std", "var")
        )
    else:
        expected_failure = False

    if expected_failure:
        pytest.xfail()
198+
199+
200+
def gen_array_by(size, func, array_module=None):
    """Generate a seeded random test array and an all-ones ``by`` array.

    Parameters
    ----------
    size : tuple of int
        Shape of the generated array. ``by`` has length ``size[-1]``.
    func : str
        Name of the reduction under test; it controls where NaNs are
        injected and whether the array is converted to booleans.
    array_module : module, optional
        Array namespace used to build the data (the tests pass ``numpy`` or
        ``cupy``). Defaults to numpy so that existing two-argument calls
        ``gen_array_by(size, func)`` keep working.

    Returns
    -------
    array, by
        ``array`` has shape ``size``; ``by`` is a 1-D array of ones.
    """
    xp = np if array_module is None else array_module
    by = xp.ones(size[-1])
    # Fixed seed keeps the generated data reproducible across runs.
    rng = xp.random.default_rng(12345)
    array = rng.random(size)
    if "nan" in func and "nanarg" not in func:
        # NaN-skipping reductions: blank out whole rows 1, 4 and 5.
        array[[1, 4, 5], ...] = xp.nan
    elif "nanarg" in func and len(size) > 1:
        # nanarg* reductions: only column 1 of those rows is set to NaN.
        array[[1, 4, 5], 1] = xp.nan
    if func in ["any", "all"]:
        # Boolean reductions operate on a thresholded boolean array.
        array = array > 0.5
    return array, by
192212

193213

194-
@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
195214
@pytest.mark.parametrize("nby", [1, 2, 3])
196215
@pytest.mark.parametrize("size", ((12,), (12, 9)))
197-
@pytest.mark.parametrize("add_nan_by", [True, False])
216+
@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
198217
@pytest.mark.parametrize("func", ALL_FUNCS)
199-
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
218+
@pytest.mark.parametrize("add_nan_by", [True, False])
219+
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine, array_module):
200220
if chunks is not None and not has_dask:
201221
pytest.skip()
202222
if "arg" in func and engine == "flox":
203223
pytest.skip()
204224

205-
array, by = gen_array_by(size, func)
225+
maybe_skip_cupy(array_module, func, engine)
226+
227+
array, by = gen_array_by(size, func, array_module)
206228
if chunks:
207229
array = dask.array.from_array(array, chunks=chunks)
208230
by = (by,) * nby
@@ -254,10 +276,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
254276
assert expected.ndim == (array.ndim + nby - 1)
255277
expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
256278
for actual_group, expect in zip(groups, expected_groups):
257-
assert_equal(actual_group, expect)
279+
assert_equal(actual_group, array_module.asarray(expect))
258280
if "arg" in func:
259281
assert actual.dtype.kind == "i"
260-
assert_equal(actual, expected, tolerance)
282+
if chunks is not None:
283+
assert isinstance(actual._meta, type(array._meta))
284+
assert_equal(actual, array_module.asarray(expected), tolerance)
261285

262286
if not has_dask or chunks is None:
263287
continue
@@ -287,6 +311,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
287311
assert_equal(actual_group, expect, tolerance)
288312
if "arg" in func:
289313
assert actual.dtype.kind == "i"
314+
if chunks is not None:
315+
assert isinstance(actual._meta, type(array._meta))
290316
assert_equal(actual, expected, tolerance)
291317

292318

@@ -313,18 +339,18 @@ def test_arg_reduction_dtype_is_int(size, func):
313339
assert actual.dtype.kind == "i"
314340

315341

316-
def test_groupby_reduce_count(array_module):
    """``count`` tallies only the non-NaN values within each label group."""
    xp = array_module
    # NOTE(review): CuPy does not appear to support string dtypes, so building
    # ``labels`` under the cupy parametrization presumably raises -- confirm.
    values = xp.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
    labels = xp.array(["a", "b", "b", "b", "c", "c", "c"])
    expected = np.array([1, 1, 2], dtype=np.intp)

    result, _ = groupby_reduce(values, labels, func="count")
    assert_equal(result, expected)
321347

322348

323-
def test_func_is_aggregation(array_module):
    """Passing ``flox.aggregations.mean`` as ``func`` matches ``func="mean"``."""
    from flox.aggregations import mean

    xp = array_module
    # NOTE(review): CuPy does not appear to support string dtypes, so building
    # ``labels`` under the cupy parametrization presumably raises -- confirm.
    values = xp.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
    labels = xp.array(["a", "b", "b", "b", "c", "c", "c"])

    # Reference result via the string name, then the same call via the object.
    expected, _ = groupby_reduce(values, labels, func="mean")
    actual, _ = groupby_reduce(values, labels, func=mean)
    assert_equal(actual, expected)

0 commit comments

Comments
 (0)