Skip to content

Commit d9b1033

Browse files
committed
Correct fast_path for arbitrary axis
1 parent de21fb5 commit d9b1033

3 files changed

Lines changed: 110 additions & 86 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from .utils import (
5252
NUMPY_GE_2_0,
5353
_get_chunk_operands,
54+
_sliced_chunk_iter,
5455
check_smaller_shape,
5556
compute_smaller_slice,
5657
constructors,
@@ -1084,10 +1085,14 @@ def get_chunk(arr, info, nchunk):
10841085

10851086
async def async_read_chunks(arrs, info, queue):
10861087
loop = asyncio.get_event_loop()
1087-
nchunks = arrs[0].schunk.nchunks
1088-
1088+
shape, chunks_ = arrs[0].shape, arrs[0].chunks
10891089
with concurrent.futures.ThreadPoolExecutor() as executor:
1090-
for nchunk in range(nchunks):
1090+
my_chunk_iter = range(arrs[0].schunk.nchunks)
1091+
if len(info) == 5:
1092+
if info[-1] is not None:
1093+
my_chunk_iter = _sliced_chunk_iter(chunks_, (), shape, axis=info[-1], nchunk=True)
1094+
info = info[:4]
1095+
for i, nchunk in enumerate(my_chunk_iter):
10911096
futures = [
10921097
(index, loop.run_in_executor(executor, get_chunk, arr, info, nchunk))
10931098
for index, arr in enumerate(arrs)
@@ -1100,7 +1105,7 @@ async def async_read_chunks(arrs, info, queue):
11001105
print(f"Exception occurred: {chunk}")
11011106
raise chunk
11021107
chunks_sorted.append(chunk)
1103-
queue.put((nchunk, chunks_sorted)) # use non-async queue.put()
1108+
queue.put((i, chunks_sorted)) # use non-async queue.put()
11041109

11051110
queue.put(None) # signal the end of the chunks
11061111

@@ -1137,7 +1142,7 @@ def read_nchunk(arrs, info):
11371142

11381143

11391144
def fill_chunk_operands(
1140-
operands, slice_, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands, reduc=False
1145+
operands, slice_, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands, reduc=False, axis=None
11411146
):
11421147
"""Retrieve the chunk operands for evaluating an expression.
11431148
@@ -1150,20 +1155,23 @@ def fill_chunk_operands(
11501155
low_mem = os.environ.get("BLOSC_LOW_MEM", False)
11511156
# This method is only useful when all operands are NDArray and shows better
11521157
# performance only when at least one of them is persisted on disk
1153-
if nchunk == 0:
1158+
if iter_chunks is None:
11541159
# Initialize the iterator for reading the chunks
11551160
# Take any operand (all should have the same shape and chunks)
11561161
key, arr = next(iter(operands.items()))
11571162
chunks_idx, _ = get_chunks_idx(arr.shape, arr.chunks)
1158-
info = (reduc, aligned[key], low_mem, chunks_idx)
1163+
info = (reduc, aligned[key], low_mem, chunks_idx, axis)
11591164
iter_chunks = read_nchunk(list(operands.values()), info)
11601165
# Run the asynchronous file reading function from a synchronous context
11611166
chunks = next(iter_chunks)
11621167

11631168
for i, (key, value) in enumerate(operands.items()):
11641169
# Chunks are already decompressed, so we can use them directly
11651170
if not low_mem:
1166-
chunk_operands[key] = chunks[i]
1171+
if full_chunk:
1172+
chunk_operands[key] = chunks[i]
1173+
else:
1174+
chunk_operands[key] = value[slice_]
11671175
continue
11681176
# Otherwise, we need to decompress them
11691177
if aligned[key]:
@@ -1568,10 +1576,12 @@ def slices_eval( # noqa: C901
15681576
intersecting_chunks = get_intersecting_chunks(
15691577
_slice, shape, chunks
15701578
) # if _slice is (), returns all chunks
1579+
ratio = np.ceil(np.asarray(shape) / np.asarray(chunks)).astype(np.int64)
15711580

1572-
for nchunk, chunk_slice in enumerate(intersecting_chunks):
1581+
for chunk_slice in intersecting_chunks:
15731582
# Check whether current cslice intersects with _slice
15741583
cslice = chunk_slice.raw
1584+
nchunk = builtins.sum([c.start // chunks[i] * np.prod(ratio[i + 1 :]) for i, c in enumerate(cslice)])
15751585
if cslice != () and _slice != ():
15761586
# get intersection of chunk and target
15771587
cslice = step_handler(cslice, _slice)
@@ -1889,9 +1899,9 @@ def reduce_slices( # noqa: C901
18891899
if np.any(mask_slice):
18901900
add_idx = np.cumsum(mask_slice)
18911901
axis = tuple(a + add_idx[a] for a in axis) # axis now refers to new shape with dummy dims
1892-
if reduce_args["axis"] is not None:
1893-
# conserve as integer if was not tuple originally
1894-
reduce_args["axis"] = axis[0] if np.isscalar(reduce_args["axis"]) else axis
1902+
if reduce_args["axis"] is not None:
1903+
# conserve as integer if was not tuple originally
1904+
reduce_args["axis"] = axis[0] if np.isscalar(reduce_args["axis"]) else axis
18951905
if reduce_op in {ReduceOp.CUMULATIVE_SUM, ReduceOp.CUMULATIVE_PROD}:
18961906
reduced_shape = (np.prod(shape_slice),) if reduce_args["axis"] is None else shape_slice
18971907
# if reduce_args["axis"] is None, have to have 1D input array; otherwise, ensure positive scalar
@@ -2057,18 +2067,17 @@ def reduce_slices( # noqa: C901
20572067
# Iterate over the operands and get the chunks
20582068
chunk_operands = {}
20592069
# Check which chunks intersect with _slice
2060-
# if chunks has 0 we loop once but fast path is false as gives error (schunk has no chunks)
2061-
if (
2062-
np.isscalar(reduce_args["axis"]) and not fast_path
2063-
): # iterate over chunks incrementing along reduction axis
2070+
if np.isscalar(reduce_args["axis"]): # iterate over chunks incrementing along reduction axis
20642071
intersecting_chunks = get_intersecting_chunks(_slice, shape, chunks, axis=reduce_args["axis"])
20652072
else: # iterate over chunks incrementing along last axis
20662073
intersecting_chunks = get_intersecting_chunks(_slice, shape, chunks)
20672074
out_init = False
20682075
res_out_init = False
2076+
ratio = np.ceil(np.asarray(shape) / np.asarray(chunks)).astype(np.int64)
20692077

2070-
for nchunk, chunk_slice in enumerate(intersecting_chunks):
2078+
for chunk_slice in intersecting_chunks:
20712079
cslice = chunk_slice.raw
2080+
nchunk = builtins.sum([c.start // chunks[i] * np.prod(ratio[i + 1 :]) for i, c in enumerate(cslice)])
20722081
# Check whether current cslice intersects with _slice
20732082
if cslice != () and _slice != ():
20742083
# get intersection of chunk and target
@@ -2077,6 +2086,8 @@ def reduce_slices( # noqa: C901
20772086
starts = [s.start if s.start is not None else 0 for s in cslice]
20782087
unit_steps = np.all([s.step == 1 for s in cslice])
20792088
cslice_shape = tuple(s.stop - s.start for s in cslice)
2089+
# get local index of part of out that is to be updated
2090+
cslice_subidx = ndindex.ndindex(cslice).as_subindex(_slice).raw # if _slice is (), just gives cslice
20802091
if _slice == () and fast_path and unit_steps:
20812092
# Fast path
20822093
full_chunk = cslice_shape == chunks
@@ -2090,12 +2101,11 @@ def reduce_slices( # noqa: C901
20902101
iter_disk,
20912102
chunk_operands,
20922103
reduc=True,
2104+
axis=reduce_args["axis"] if np.isscalar(reduce_args["axis"]) else None,
20932105
)
20942106
else:
20952107
_get_chunk_operands(operands, cslice, chunk_operands, shape)
20962108

2097-
# get local index of part of out that is to be updated
2098-
cslice_subidx = ndindex.ndindex(cslice).as_subindex(_slice).raw # if _slice is (), just gives cslice
20992109
if reduce_op in {ReduceOp.CUMULATIVE_PROD, ReduceOp.CUMULATIVE_SUM}:
21002110
reduced_slice = (
21012111
tuple(
@@ -3160,22 +3170,7 @@ def find_args(expr):
31603170

31613171
return value, expression[idx:idx2]
31623172

3163-
def _compute_expr(self, item, kwargs): # noqa : C901
3164-
# ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
3165-
# that are implemented in python-blosc2 not in numexpr
3166-
global safe_blosc2_globals
3167-
if len(safe_blosc2_globals) == 0:
3168-
# First eval call, fill blosc2_safe_globals for ne_evaluate
3169-
safe_blosc2_globals = {"blosc2": blosc2}
3170-
# Add all first-level blosc2 functions
3171-
safe_blosc2_globals.update(
3172-
{
3173-
name: getattr(blosc2, name)
3174-
for name in dir(blosc2)
3175-
if callable(getattr(blosc2, name)) and not name.startswith("_")
3176-
}
3177-
)
3178-
3173+
def _compute_expr(self, item, kwargs):
31793174
if any(method in self.expression for method in eager_funcs):
31803175
# We have reductions in the expression (probably coming from a string lazyexpr)
31813176
# Also includes slice

src/blosc2/utils.py

Lines changed: 63 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -769,65 +769,81 @@ def _get_local_slice(prior_selection, post_selection, chunk_bounds):
769769
return locbegin, locend
770770

771771

772-
def get_intersecting_chunks(idx, shape, chunks, axis=None):
773-
if 0 in chunks: # chunk is whole array so just return full tuple to do loop once
774-
return (ndindex.ndindex(...).expand(shape),)
775-
chunk_size = ndindex.ChunkSize(chunks)
776-
if axis is None:
777-
return chunk_size.as_subchunks(idx, shape) # if _slice is (), returns all chunks
778-
779-
def return_my_it(chunk_size, idx, shape, axis):
780-
# special algorithm to iterate over axis first (adapted from ndindex source)
781-
shape = ndindex.shapetools.asshape(shape)
772+
def _sliced_chunk_iter(chunks, idx, shape, axis=None, nchunk=False):
773+
"""
774+
If nchunk is True, return an iterator over the chunk numbers instead of chunk slices.
775+
"""
776+
ratio = np.ceil(np.asarray(shape) / np.asarray(chunks)).astype(np.int64)
777+
idx = ndindex.ndindex(idx).expand(shape)
778+
if axis is not None:
779+
idx = tuple(a for i, a in enumerate(idx.args) if i != axis) + (idx.args[axis],)
780+
chunks_ = tuple(a for i, a in enumerate(chunks) if i != axis) + (chunks[axis],)
781+
else:
782+
chunks_ = chunks
783+
idx_iter = iter(idx) # iterate over tuple of slices in order
784+
chunk_iter = iter(chunks_) # iterate over chunk_shape in order
785+
786+
iters = []
787+
while True:
788+
try:
789+
i = next(idx_iter) # slice along axis
790+
n = next(chunk_iter) # chunklen along dimension
791+
except StopIteration:
792+
break
793+
if not isinstance(i, ndindex.Slice):
794+
raise ValueError("Only slices may be used with axis arg")
795+
796+
def _slice_iter(s, n):
797+
a, N, m = s.args
798+
if m > n:
799+
yield from ((a + k * m) // n for k in range(ceiling(N - a, m)))
800+
else:
801+
yield from range(a // n, ceiling(N, n))
782802

783-
idx = ndindex.ndindex(idx).expand(shape)
803+
iters.append(_slice_iter(i, n))
784804

785-
iters = []
786-
idx_args = tuple(a for i, a in enumerate(idx.args) if i != axis) + (idx.args[axis],)
787-
idx_args = iter(idx_args) # iterate over tuple of slices in order
788-
self_ = tuple(a for i, a in enumerate(chunk_size) if i != axis) + (chunk_size[axis],)
789-
self_ = iter(self_) # iterate over chunk_shape in order
790-
while True:
791-
try:
792-
i = next(idx_args) # slice along axis
793-
n = next(self_) # chunklen along dimension
794-
except StopIteration:
795-
break
796-
if not isinstance(i, ndindex.Slice):
797-
raise ValueError("Only slices may be used with axis arg")
798-
799-
def _slice_iter(s, n):
800-
a, N, m = s.args
801-
if m > n:
802-
yield from ((a + k * m) // n for k in range(ceiling(N - a, m)))
803-
else:
804-
yield from range(a // n, ceiling(N, n))
805-
806-
iters.append(_slice_iter(i, n))
807-
808-
def _indices(iters):
809-
my_list = [ndindex.Slice(None, None)] * len(chunk_size)
810-
for p in product(*iters):
811-
# p increments over arg axis first before other axes
812-
# p = (...., -1, axis)
805+
def _indices(iters):
806+
my_list = [ndindex.Slice(None, None)] * len(chunks)
807+
for p in product(*iters):
808+
# p increments over arg axis first before other axes
809+
# p = (...., -1, axis)
810+
if axis is None:
811+
my_list = [
812+
ndindex.Slice(cs * ci, min(cs * (ci + 1), n), 1)
813+
for n, cs, ci in zip(shape, chunks, p, strict=True)
814+
]
815+
else:
813816
my_list[:axis] = [
814817
ndindex.Slice(cs * ci, min(cs * (ci + 1), n), 1)
815-
for n, cs, ci in zip(shape[:axis], chunk_size[:axis], p[:axis], strict=True)
818+
for n, cs, ci in zip(shape[:axis], chunks[:axis], p[:axis], strict=True)
816819
]
817-
n, cs, ci = shape[-1], chunk_size[-1], p[-1]
820+
n, cs, ci = shape[axis], chunks[axis], p[-1]
818821
my_list[axis] = ndindex.Slice(cs * ci, min(cs * (ci + 1), n), 1)
819822
my_list[axis + 1 :] = [
820823
ndindex.Slice(cs * ci, min(cs * (ci + 1), n), 1)
821-
for n, cs, ci in zip(shape[axis:-1], chunk_size[axis:-1], p[axis:-1], strict=True)
824+
for n, cs, ci in zip(shape[axis + 1 :], chunks[axis + 1 :], p[axis:-1], strict=True)
822825
]
826+
if nchunk:
827+
yield builtins.sum(
828+
[c.start // chunks[i] * np.prod(ratio[i + 1 :]) for i, c in enumerate(my_list)]
829+
)
830+
else:
823831
yield ndindex.Tuple(*my_list)
824832

825-
for c in _indices(iters):
826-
# Empty indices should be impossible by the construction of the
827-
# iterators above.
828-
yield from c
833+
yield from _indices(iters)
834+
835+
836+
def get_intersecting_chunks(idx, shape, chunks, axis=None):
837+
if len(chunks) != len(shape):
838+
raise ValueError("chunks must be same length as shape!")
839+
if 0 in chunks: # chunk is whole array so just return full tuple to do loop once
840+
return (ndindex.ndindex(...).expand(shape),), range(0)
841+
chunk_size = ndindex.ChunkSize(chunks)
842+
if axis is None:
843+
return chunk_size.as_subchunks(idx, shape) # if _slice is (), returns all chunks
829844

830-
return return_my_it(chunk_size, idx, shape, axis)
845+
# special algorithm to iterate over axis first (adapted from ndindex source)
846+
return _sliced_chunk_iter(chunks, idx, shape, axis)
831847

832848

833849
def get_chunks_idx(shape, chunks):

tests/ndarray/test_reductions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,15 @@ def test_reduce_slice(reduce_op):
373373
@pytest.mark.parametrize(
374374
("chunks", "blocks"),
375375
[
376+
((10, 50, 70), (10, 25, 50)),
376377
((20, 50, 100), (10, 50, 100)),
377-
((10, 25, 70), (10, 25, 50)),
378378
((10, 50, 100), (6, 25, 75)),
379379
((15, 30, 75), (7, 20, 50)),
380-
((20, 50, 100), (10, 50, 60)),
380+
((1, 50, 100), (1, 50, 60)),
381381
],
382382
)
383383
@pytest.mark.parametrize("disk", [True, False])
384-
@pytest.mark.parametrize("fill_value", [0, 1, 0.32])
384+
@pytest.mark.parametrize("fill_value", [1, 0, 0.32])
385385
@pytest.mark.parametrize(
386386
"reduce_op",
387387
[
@@ -400,7 +400,7 @@ def test_reduce_slice(reduce_op):
400400
"cumulative_prod",
401401
],
402402
)
403-
@pytest.mark.parametrize("axis", [0, 1, None])
403+
@pytest.mark.parametrize("axis", [None, 0, 1])
404404
def test_fast_path(chunks, blocks, disk, fill_value, reduce_op, axis):
405405
shape = (20, 50, 100)
406406
urlpath = "a1.b2nd" if disk else None
@@ -420,6 +420,19 @@ def test_fast_path(chunks, blocks, disk, fill_value, reduce_op, axis):
420420
res = getattr(a, reduce_op)(axis=axis)
421421
assert np.allclose(res, nres)
422422

423+
# Try with a slice
424+
b = blosc2.arange(0, np.prod(shape), blocks=blocks, chunks=chunks, shape=shape)
425+
nb = b[:]
426+
slice_ = (slice(10, 20),)
427+
if reduce_op in {"cumulative_sum", "cumulative_prod"}:
428+
axis = 0 if axis is None else axis
429+
oploc = "npcumsum" if reduce_op == "cumulative_sum" else "npcumprod"
430+
nres = eval(f"{oploc}((na + nb)[{slice_}], axis={axis})")
431+
else:
432+
nres = getattr((na + nb)[slice_], reduce_op)(axis=axis)
433+
res = getattr(a + b, reduce_op)(axis=axis, item=slice_)
434+
assert np.allclose(res, nres)
435+
423436

424437
@pytest.mark.parametrize("disk", [True, False])
425438
@pytest.mark.parametrize("fill_value", [0, 1, 0.32])

0 commit comments

Comments
 (0)