@@ -2072,7 +2072,6 @@ def reduce_slices( # noqa: C901
20722072 else : # iterate over chunks incrementing along last axis
20732073 intersecting_chunks = get_intersecting_chunks (_slice , shape , chunks )
20742074 out_init = False
2075- res_out_init = False
20762075 ratio = np .ceil (np .asarray (shape ) / np .asarray (chunks )).astype (np .int64 )
20772076
20782077 for chunk_slice in intersecting_chunks :
@@ -2177,7 +2176,9 @@ def reduce_slices( # noqa: C901
21772176 result = reduce_op .value .reduce (result , ** reduce_args )
21782177
21792178 if not out_init :
2180- out_ = convert_none_out (result .dtype , reduce_op , reduced_shape )
2179+ out_ , res_out_ = convert_none_out (
2180+ result .dtype , reduce_op , reduced_shape , axis = reduce_args ["axis" ]
2181+ )
21812182 # if reduce_op == ReduceOp.CUMULATIVE_SUM:
21822183 # kahan_sum = np.zeros_like(res_out_)
21832184 if out is not None :
@@ -2187,12 +2188,6 @@ def reduce_slices( # noqa: C901
21872188 out = out_
21882189 out_init = True
21892190
2190- if (reduce_args ["axis" ] is None and not res_out_init ) or (
2191- np .isscalar (reduce_args ["axis" ]) and cslice_subidx [reduce_args ["axis" ]].start == 0
2192- ): # starting reduction again along axis
2193- res_out_ = _get_res_out (result .shape , reduce_args ["axis" ], dtype , reduce_op )
2194- res_out_init = True
2195-
21962191 # Update the output array with the result
21972192 if reduce_op == ReduceOp .ANY :
21982193 out [reduced_slice ] += result
@@ -2201,29 +2196,33 @@ def reduce_slices( # noqa: C901
22012196 elif res_out_ is not None :
22022197 # need lowest index for which optimum attained
22032198 if reduce_op in {ReduceOp .ARGMAX , ReduceOp .ARGMIN }:
2204- cond = (res_out_ == result ) & (result_idx < out [reduced_slice ])
2205- cond |= res_out_ < result if reduce_op == ReduceOp .ARGMAX else res_out_ > result
2199+ cond = (res_out_ [reduced_slice ] == result ) & (result_idx < out [reduced_slice ])
2200+ cond |= (
2201+ res_out_ [reduced_slice ] < result
2202+ if reduce_op == ReduceOp .ARGMAX
2203+ else res_out_ [reduced_slice ] > result
2204+ )
22062205 out [reduced_slice ] = np .where (cond , result_idx , out [reduced_slice ])
2207- res_out_ = np .where (cond , result , res_out_ )
2206+ res_out_ [ reduced_slice ] = np .where (cond , result , res_out_ [ reduced_slice ] )
22082207 else : # CUMULATIVE_SUM or CUMULATIVE_PROD
22092208 idx_result = tuple (
22102209 slice (- 1 , None ) if i == reduce_args ["axis" ] else slice (None , None )
22112210 for i , c in enumerate (reduced_slice )
22122211 )
2213- # idx_lastval = tuple(
2214- # slice(0, 1) if i == reduce_args["axis"] else c for i, c in enumerate(reduced_slice)
2215- # )
2212+ idx_lastval = tuple (
2213+ slice (0 , 1 ) if i == reduce_args ["axis" ] else c for i , c in enumerate (reduced_slice )
2214+ )
22162215 if reduce_op == ReduceOp .CUMULATIVE_SUM :
22172216 # use Kahan summation algorithm for better precision
22182217 # y = res_out_[idx_lastval] - kahan_sum[idx_lastval]
22192218 # t = result + y
22202219 # kahan_sum[idx_lastval] = ((t - result) - y)[idx_result]
22212220 # result = t
2222- result += res_out_
2221+ result += res_out_ [ idx_lastval ]
22232222 else : # CUMULATIVE_PROD
2224- result *= res_out_
2223+ result *= res_out_ [ idx_lastval ]
22252224 out [reduced_slice ] = result
2226- res_out_ = result [idx_result ]
2225+ res_out_ [ idx_lastval ] = result [idx_result ]
22272226 else :
22282227 out [reduced_slice ] = reduce_op .value (out [reduced_slice ], result )
22292228
@@ -2236,7 +2235,7 @@ def reduce_slices( # noqa: C901
22362235 if dtype is None :
22372236 # We have no hint here, so choose a default dtype
22382237 dtype = np .float64
2239- out = convert_none_out (dtype , reduce_op , reduced_shape )
2238+ out , _ = convert_none_out (dtype , reduce_op , reduced_shape )
22402239
22412240 out = out [()] if reduced_shape == () else out # undo dummy dim from inside convert_none_out
22422241 final_mask = tuple (np .where (mask_slice )[0 ])
@@ -2248,31 +2247,8 @@ def reduce_slices( # noqa: C901
22482247 return out
22492248
22502249
def _get_res_out(reduced_shape, axis, dtype, reduce_op):
    """Build the auxiliary accumulator used while reducing chunk by chunk.

    Parameters
    ----------
    reduced_shape : tuple
        Shape of the partial result; ``()`` is treated as ``(1,)``.
    axis : int or None
        Reduction axis. Only used for cumulative ops, where the accumulator
        gets length 1 along this axis. With ``axis=None`` no dimension
        matches, so the accumulator keeps the full ``reduced_shape``.
    dtype : numpy dtype
        Dtype of the accumulator.
    reduce_op : ReduceOp
        The reduction being performed.

    Returns
    -------
    numpy.ndarray or None
        * CUMULATIVE_SUM / CUMULATIVE_PROD: zeros / ones (the identity of the
          operation) holding the running total carried between chunks.
        * ARGMIN / ARGMAX: an array filled with the worst possible value for
          ``dtype`` so that any real value compares as better.
        * any other op: ``None``.
    """
    reduced_shape = (1,) if reduced_shape == () else reduced_shape

    if reduce_op in {ReduceOp.CUMULATIVE_SUM, ReduceOp.CUMULATIVE_PROD}:
        # Collapse the reduction axis to length 1: only the last running value is kept.
        temp_shape = tuple(1 if i == axis else s for i, s in enumerate(reduced_shape))
        if reduce_op == ReduceOp.CUMULATIVE_SUM:
            return np.zeros(temp_shape, dtype=dtype)
        return np.ones(temp_shape, dtype=dtype)

    if reduce_op in {ReduceOp.ARGMIN, ReduceOp.ARGMAX}:
        seeking_min = reduce_op == ReduceOp.ARGMIN
        res_out_ = np.ones(reduced_shape, dtype=dtype)
        if np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            res_out_ *= limits.max if seeking_min else limits.min
        elif np.issubdtype(dtype, np.bool_):
            # np.bool_ (not the np.bool alias, which is missing in NumPy 1.24-1.26).
            # Worst value for a bool is True when seeking the min, False when seeking the max.
            if not seeking_min:
                res_out_ = np.zeros(reduced_shape, dtype=dtype)
        else:
            res_out_ *= np.inf if seeking_min else -np.inf
        return res_out_

    # No auxiliary accumulator is needed for the remaining reductions.
    return None
2273-
2274-
2275- def convert_none_out (dtype , reduce_op , reduced_shape ):
2250+ def convert_none_out (dtype , reduce_op , reduced_shape , axis = None ):
2251+ out = None
22762252 reduced_shape = (1 ,) if reduced_shape == () else reduced_shape
22772253 # out will be a proper numpy.ndarray
22782254 if reduce_op in {ReduceOp .SUM , ReduceOp .CUMULATIVE_SUM , ReduceOp .PROD , ReduceOp .CUMULATIVE_PROD }:
@@ -2281,6 +2257,17 @@ def convert_none_out(dtype, reduce_op, reduced_shape):
22812257 if reduce_op in {ReduceOp .SUM , ReduceOp .CUMULATIVE_SUM }
22822258 else np .ones (reduced_shape , dtype = dtype )
22832259 )
2260+ # Get res_out to hold running sums along axes for chunks when doing cumulative sums/prods with axis not None
2261+ if reduce_op in {ReduceOp .SUM , ReduceOp .PROD }:
2262+ res_out_ = None
2263+ else :
2264+ temp_shape = tuple (1 if i == axis else s for i , s in enumerate (reduced_shape ))
2265+ res_out_ = (
2266+ np .zeros (temp_shape , dtype = dtype )
2267+ if reduce_op == ReduceOp .CUMULATIVE_SUM
2268+ else np .ones (temp_shape , dtype = dtype )
2269+ )
2270+ out = (out , res_out_ )
22842271 elif reduce_op == ReduceOp .MIN :
22852272 if np .issubdtype (dtype , np .integer ):
22862273 out = np .iinfo (dtype ).max * np .ones (reduced_shape , dtype = dtype )
@@ -2296,8 +2283,15 @@ def convert_none_out(dtype, reduce_op, reduced_shape):
22962283 elif reduce_op == ReduceOp .ALL :
22972284 out = np .ones (reduced_shape , dtype = np .bool_ )
22982285 elif reduce_op in {ReduceOp .ARGMIN , ReduceOp .ARGMAX }:
2299- out = np .zeros (reduced_shape , dtype = blosc2 .DEFAULT_INDEX )
2300- return out
2286+ res_out_ = np .ones (reduced_shape , dtype = dtype )
2287+ if np .issubdtype (dtype , np .integer ):
2288+ res_out_ *= np .iinfo (dtype ).max if reduce_op == ReduceOp .ARGMIN else np .iinfo (dtype ).min
2289+ elif np .issubdtype (dtype , np .bool ):
2290+ res_out_ = res_out_ if reduce_op == ReduceOp .ARGMIN else np .zeros (reduced_shape , dtype = dtype )
2291+ else :
2292+ res_out_ *= np .inf if reduce_op == ReduceOp .ARGMIN else - np .inf
2293+ out = (np .zeros (reduced_shape , dtype = blosc2 .DEFAULT_INDEX ), res_out_ )
2294+ return out if isinstance (out , tuple ) else (out , None )
23012295
23022296
23032297def chunked_eval (
0 commit comments