Skip to content

Commit c741f12

Browse files
committed
Add product import
2 parents d9b1033 + f418502 commit c741f12

File tree

1 file changed

+39
-45
lines changed

1 file changed

+39
-45
lines changed

src/blosc2/lazyexpr.py

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,6 @@ def reduce_slices( # noqa: C901
20722072
else: # iterate over chunks incrementing along last axis
20732073
intersecting_chunks = get_intersecting_chunks(_slice, shape, chunks)
20742074
out_init = False
2075-
res_out_init = False
20762075
ratio = np.ceil(np.asarray(shape) / np.asarray(chunks)).astype(np.int64)
20772076

20782077
for chunk_slice in intersecting_chunks:
@@ -2177,7 +2176,9 @@ def reduce_slices( # noqa: C901
21772176
result = reduce_op.value.reduce(result, **reduce_args)
21782177

21792178
if not out_init:
2180-
out_ = convert_none_out(result.dtype, reduce_op, reduced_shape)
2179+
out_, res_out_ = convert_none_out(
2180+
result.dtype, reduce_op, reduced_shape, axis=reduce_args["axis"]
2181+
)
21812182
# if reduce_op == ReduceOp.CUMULATIVE_SUM:
21822183
# kahan_sum = np.zeros_like(res_out_)
21832184
if out is not None:
@@ -2187,12 +2188,6 @@ def reduce_slices( # noqa: C901
21872188
out = out_
21882189
out_init = True
21892190

2190-
if (reduce_args["axis"] is None and not res_out_init) or (
2191-
np.isscalar(reduce_args["axis"]) and cslice_subidx[reduce_args["axis"]].start == 0
2192-
): # starting reduction again along axis
2193-
res_out_ = _get_res_out(result.shape, reduce_args["axis"], dtype, reduce_op)
2194-
res_out_init = True
2195-
21962191
# Update the output array with the result
21972192
if reduce_op == ReduceOp.ANY:
21982193
out[reduced_slice] += result
@@ -2201,29 +2196,33 @@ def reduce_slices( # noqa: C901
22012196
elif res_out_ is not None:
22022197
# need lowest index for which optimum attained
22032198
if reduce_op in {ReduceOp.ARGMAX, ReduceOp.ARGMIN}:
2204-
cond = (res_out_ == result) & (result_idx < out[reduced_slice])
2205-
cond |= res_out_ < result if reduce_op == ReduceOp.ARGMAX else res_out_ > result
2199+
cond = (res_out_[reduced_slice] == result) & (result_idx < out[reduced_slice])
2200+
cond |= (
2201+
res_out_[reduced_slice] < result
2202+
if reduce_op == ReduceOp.ARGMAX
2203+
else res_out_[reduced_slice] > result
2204+
)
22062205
out[reduced_slice] = np.where(cond, result_idx, out[reduced_slice])
2207-
res_out_ = np.where(cond, result, res_out_)
2206+
res_out_[reduced_slice] = np.where(cond, result, res_out_[reduced_slice])
22082207
else: # CUMULATIVE_SUM or CUMULATIVE_PROD
22092208
idx_result = tuple(
22102209
slice(-1, None) if i == reduce_args["axis"] else slice(None, None)
22112210
for i, c in enumerate(reduced_slice)
22122211
)
2213-
# idx_lastval = tuple(
2214-
# slice(0, 1) if i == reduce_args["axis"] else c for i, c in enumerate(reduced_slice)
2215-
# )
2212+
idx_lastval = tuple(
2213+
slice(0, 1) if i == reduce_args["axis"] else c for i, c in enumerate(reduced_slice)
2214+
)
22162215
if reduce_op == ReduceOp.CUMULATIVE_SUM:
22172216
# use Kahan summation algorithm for better precision
22182217
# y = res_out_[idx_lastval] - kahan_sum[idx_lastval]
22192218
# t = result + y
22202219
# kahan_sum[idx_lastval] = ((t - result) - y)[idx_result]
22212220
# result = t
2222-
result += res_out_
2221+
result += res_out_[idx_lastval]
22232222
else: # CUMULATIVE_PROD
2224-
result *= res_out_
2223+
result *= res_out_[idx_lastval]
22252224
out[reduced_slice] = result
2226-
res_out_ = result[idx_result]
2225+
res_out_[idx_lastval] = result[idx_result]
22272226
else:
22282227
out[reduced_slice] = reduce_op.value(out[reduced_slice], result)
22292228

@@ -2236,7 +2235,7 @@ def reduce_slices( # noqa: C901
22362235
if dtype is None:
22372236
# We have no hint here, so choose a default dtype
22382237
dtype = np.float64
2239-
out = convert_none_out(dtype, reduce_op, reduced_shape)
2238+
out, _ = convert_none_out(dtype, reduce_op, reduced_shape)
22402239

22412240
out = out[()] if reduced_shape == () else out # undo dummy dim from inside convert_none_out
22422241
final_mask = tuple(np.where(mask_slice)[0])
@@ -2248,31 +2247,8 @@ def reduce_slices( # noqa: C901
22482247
return out
22492248

22502249

2251-
def _get_res_out(reduced_shape, axis, dtype, reduce_op):
2252-
reduced_shape = (1,) if reduced_shape == () else reduced_shape
2253-
# Get res_out to hold running sums along axes for chunks when doing cumulative sums/prods with axis not None
2254-
if reduce_op in {ReduceOp.CUMULATIVE_SUM, ReduceOp.CUMULATIVE_PROD}:
2255-
temp_shape = tuple(1 if i == axis else s for i, s in enumerate(reduced_shape))
2256-
res_out_ = (
2257-
np.zeros(temp_shape, dtype=dtype)
2258-
if reduce_op == ReduceOp.CUMULATIVE_SUM
2259-
else np.ones(temp_shape, dtype=dtype)
2260-
)
2261-
elif reduce_op in {ReduceOp.ARGMIN, ReduceOp.ARGMAX}:
2262-
temp_shape = reduced_shape
2263-
res_out_ = np.ones(temp_shape, dtype=dtype)
2264-
if np.issubdtype(dtype, np.integer):
2265-
res_out_ *= np.iinfo(dtype).max if reduce_op == ReduceOp.ARGMIN else np.iinfo(dtype).min
2266-
elif np.issubdtype(dtype, np.bool):
2267-
res_out_ = res_out_ if reduce_op == ReduceOp.ARGMIN else np.zeros(temp_shape, dtype=dtype)
2268-
else:
2269-
res_out_ *= np.inf if reduce_op == ReduceOp.ARGMIN else -np.inf
2270-
else:
2271-
res_out_ = None
2272-
return res_out_
2273-
2274-
2275-
def convert_none_out(dtype, reduce_op, reduced_shape):
2250+
def convert_none_out(dtype, reduce_op, reduced_shape, axis=None):
2251+
out = None
22762252
reduced_shape = (1,) if reduced_shape == () else reduced_shape
22772253
# out will be a proper numpy.ndarray
22782254
if reduce_op in {ReduceOp.SUM, ReduceOp.CUMULATIVE_SUM, ReduceOp.PROD, ReduceOp.CUMULATIVE_PROD}:
@@ -2281,6 +2257,17 @@ def convert_none_out(dtype, reduce_op, reduced_shape):
22812257
if reduce_op in {ReduceOp.SUM, ReduceOp.CUMULATIVE_SUM}
22822258
else np.ones(reduced_shape, dtype=dtype)
22832259
)
2260+
# Get res_out to hold running sums along axes for chunks when doing cumulative sums/prods with axis not None
2261+
if reduce_op in {ReduceOp.SUM, ReduceOp.PROD}:
2262+
res_out_ = None
2263+
else:
2264+
temp_shape = tuple(1 if i == axis else s for i, s in enumerate(reduced_shape))
2265+
res_out_ = (
2266+
np.zeros(temp_shape, dtype=dtype)
2267+
if reduce_op == ReduceOp.CUMULATIVE_SUM
2268+
else np.ones(temp_shape, dtype=dtype)
2269+
)
2270+
out = (out, res_out_)
22842271
elif reduce_op == ReduceOp.MIN:
22852272
if np.issubdtype(dtype, np.integer):
22862273
out = np.iinfo(dtype).max * np.ones(reduced_shape, dtype=dtype)
@@ -2296,8 +2283,15 @@ def convert_none_out(dtype, reduce_op, reduced_shape):
22962283
elif reduce_op == ReduceOp.ALL:
22972284
out = np.ones(reduced_shape, dtype=np.bool_)
22982285
elif reduce_op in {ReduceOp.ARGMIN, ReduceOp.ARGMAX}:
2299-
out = np.zeros(reduced_shape, dtype=blosc2.DEFAULT_INDEX)
2300-
return out
2286+
res_out_ = np.ones(reduced_shape, dtype=dtype)
2287+
if np.issubdtype(dtype, np.integer):
2288+
res_out_ *= np.iinfo(dtype).max if reduce_op == ReduceOp.ARGMIN else np.iinfo(dtype).min
2289+
elif np.issubdtype(dtype, np.bool):
2290+
res_out_ = res_out_ if reduce_op == ReduceOp.ARGMIN else np.zeros(reduced_shape, dtype=dtype)
2291+
else:
2292+
res_out_ *= np.inf if reduce_op == ReduceOp.ARGMIN else -np.inf
2293+
out = (np.zeros(reduced_shape, dtype=blosc2.DEFAULT_INDEX), res_out_)
2294+
return out if isinstance(out, tuple) else (out, None)
23012295

23022296

23032297
def chunked_eval(

0 commit comments

Comments
 (0)