5151from .utils import (
5252 NUMPY_GE_2_0 ,
5353 _get_chunk_operands ,
54+ _sliced_chunk_iter ,
5455 check_smaller_shape ,
5556 compute_smaller_slice ,
5657 constructors ,
@@ -1084,10 +1085,14 @@ def get_chunk(arr, info, nchunk):
10841085
10851086async def async_read_chunks (arrs , info , queue ):
10861087 loop = asyncio .get_event_loop ()
1087- nchunks = arrs [0 ].schunk .nchunks
1088-
1088+ shape , chunks_ = arrs [0 ].shape , arrs [0 ].chunks
10891089 with concurrent .futures .ThreadPoolExecutor () as executor :
1090- for nchunk in range (nchunks ):
1090+ my_chunk_iter = range (arrs [0 ].schunk .nchunks )
1091+ if len (info ) == 5 :
1092+ if info [- 1 ] is not None :
1093+ my_chunk_iter = _sliced_chunk_iter (chunks_ , (), shape , axis = info [- 1 ], nchunk = True )
1094+ info = info [:4 ]
1095+ for i , nchunk in enumerate (my_chunk_iter ):
10911096 futures = [
10921097 (index , loop .run_in_executor (executor , get_chunk , arr , info , nchunk ))
10931098 for index , arr in enumerate (arrs )
@@ -1100,7 +1105,7 @@ async def async_read_chunks(arrs, info, queue):
11001105 print (f"Exception occurred: { chunk } " )
11011106 raise chunk
11021107 chunks_sorted .append (chunk )
1103- queue .put ((nchunk , chunks_sorted )) # use non-async queue.put()
1108+ queue .put ((i , chunks_sorted )) # use non-async queue.put()
11041109
11051110 queue .put (None ) # signal the end of the chunks
11061111
@@ -1137,7 +1142,7 @@ def read_nchunk(arrs, info):
11371142
11381143
11391144def fill_chunk_operands (
1140- operands , slice_ , chunks_ , full_chunk , aligned , nchunk , iter_disk , chunk_operands , reduc = False
1145+ operands , slice_ , chunks_ , full_chunk , aligned , nchunk , iter_disk , chunk_operands , reduc = False , axis = None
11411146):
11421147 """Retrieve the chunk operands for evaluating an expression.
11431148
@@ -1150,20 +1155,23 @@ def fill_chunk_operands(
11501155 low_mem = os .environ .get ("BLOSC_LOW_MEM" , False )
11511156 # This method is only useful when all operands are NDArray and shows better
11521157 # performance only when at least one of them is persisted on disk
1153- if nchunk == 0 :
1158+ if iter_chunks is None :
11541159 # Initialize the iterator for reading the chunks
11551160 # Take any operand (all should have the same shape and chunks)
11561161 key , arr = next (iter (operands .items ()))
11571162 chunks_idx , _ = get_chunks_idx (arr .shape , arr .chunks )
1158- info = (reduc , aligned [key ], low_mem , chunks_idx )
1163+ info = (reduc , aligned [key ], low_mem , chunks_idx , axis )
11591164 iter_chunks = read_nchunk (list (operands .values ()), info )
11601165 # Run the asynchronous file reading function from a synchronous context
11611166 chunks = next (iter_chunks )
11621167
11631168 for i , (key , value ) in enumerate (operands .items ()):
11641169 # Chunks are already decompressed, so we can use them directly
11651170 if not low_mem :
1166- chunk_operands [key ] = chunks [i ]
1171+ if full_chunk :
1172+ chunk_operands [key ] = chunks [i ]
1173+ else :
1174+ chunk_operands [key ] = value [slice_ ]
11671175 continue
11681176 # Otherwise, we need to decompress them
11691177 if aligned [key ]:
@@ -1568,10 +1576,12 @@ def slices_eval( # noqa: C901
15681576 intersecting_chunks = get_intersecting_chunks (
15691577 _slice , shape , chunks
15701578 ) # if _slice is (), returns all chunks
1579+ ratio = np .ceil (np .asarray (shape ) / np .asarray (chunks )).astype (np .int64 )
15711580
1572- for nchunk , chunk_slice in enumerate ( intersecting_chunks ) :
1581+ for chunk_slice in intersecting_chunks :
15731582 # Check whether current cslice intersects with _slice
15741583 cslice = chunk_slice .raw
1584+ nchunk = builtins .sum ([c .start // chunks [i ] * np .prod (ratio [i + 1 :]) for i , c in enumerate (cslice )])
15751585 if cslice != () and _slice != ():
15761586 # get intersection of chunk and target
15771587 cslice = step_handler (cslice , _slice )
@@ -1889,9 +1899,9 @@ def reduce_slices( # noqa: C901
18891899 if np .any (mask_slice ):
18901900 add_idx = np .cumsum (mask_slice )
18911901 axis = tuple (a + add_idx [a ] for a in axis ) # axis now refers to new shape with dummy dims
1892- if reduce_args ["axis" ] is not None :
1893- # conserve as integer if was not tuple originally
1894- reduce_args ["axis" ] = axis [0 ] if np .isscalar (reduce_args ["axis" ]) else axis
1902+ if reduce_args ["axis" ] is not None :
1903+ # conserve as integer if was not tuple originally
1904+ reduce_args ["axis" ] = axis [0 ] if np .isscalar (reduce_args ["axis" ]) else axis
18951905 if reduce_op in {ReduceOp .CUMULATIVE_SUM , ReduceOp .CUMULATIVE_PROD }:
18961906 reduced_shape = (np .prod (shape_slice ),) if reduce_args ["axis" ] is None else shape_slice
18971907 # if reduce_args["axis"] is None, have to have 1D input array; otherwise, ensure positive scalar
@@ -2057,18 +2067,17 @@ def reduce_slices( # noqa: C901
20572067 # Iterate over the operands and get the chunks
20582068 chunk_operands = {}
20592069 # Check which chunks intersect with _slice
2060- # if chunks has 0 we loop once but fast path is false as gives error (schunk has no chunks)
2061- if (
2062- np .isscalar (reduce_args ["axis" ]) and not fast_path
2063- ): # iterate over chunks incrementing along reduction axis
2070+ if np .isscalar (reduce_args ["axis" ]): # iterate over chunks incrementing along reduction axis
20642071 intersecting_chunks = get_intersecting_chunks (_slice , shape , chunks , axis = reduce_args ["axis" ])
20652072 else : # iterate over chunks incrementing along last axis
20662073 intersecting_chunks = get_intersecting_chunks (_slice , shape , chunks )
20672074 out_init = False
20682075 res_out_init = False
2076+ ratio = np .ceil (np .asarray (shape ) / np .asarray (chunks )).astype (np .int64 )
20692077
2070- for nchunk , chunk_slice in enumerate ( intersecting_chunks ) :
2078+ for chunk_slice in intersecting_chunks :
20712079 cslice = chunk_slice .raw
2080+ nchunk = builtins .sum ([c .start // chunks [i ] * np .prod (ratio [i + 1 :]) for i , c in enumerate (cslice )])
20722081 # Check whether current cslice intersects with _slice
20732082 if cslice != () and _slice != ():
20742083 # get intersection of chunk and target
@@ -2077,6 +2086,8 @@ def reduce_slices( # noqa: C901
20772086 starts = [s .start if s .start is not None else 0 for s in cslice ]
20782087 unit_steps = np .all ([s .step == 1 for s in cslice ])
20792088 cslice_shape = tuple (s .stop - s .start for s in cslice )
2089+ # get local index of part of out that is to be updated
2090+ cslice_subidx = ndindex .ndindex (cslice ).as_subindex (_slice ).raw # if _slice is (), just gives cslice
20802091 if _slice == () and fast_path and unit_steps :
20812092 # Fast path
20822093 full_chunk = cslice_shape == chunks
@@ -2090,12 +2101,11 @@ def reduce_slices( # noqa: C901
20902101 iter_disk ,
20912102 chunk_operands ,
20922103 reduc = True ,
2104+ axis = reduce_args ["axis" ] if np .isscalar (reduce_args ["axis" ]) else None ,
20932105 )
20942106 else :
20952107 _get_chunk_operands (operands , cslice , chunk_operands , shape )
20962108
2097- # get local index of part of out that is to be updated
2098- cslice_subidx = ndindex .ndindex (cslice ).as_subindex (_slice ).raw # if _slice is (), just gives cslice
20992109 if reduce_op in {ReduceOp .CUMULATIVE_PROD , ReduceOp .CUMULATIVE_SUM }:
21002110 reduced_slice = (
21012111 tuple (
@@ -3160,22 +3170,7 @@ def find_args(expr):
31603170
31613171 return value , expression [idx :idx2 ]
31623172
3163- def _compute_expr (self , item , kwargs ): # noqa : C901
3164- # ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
3165- # that are implemented in python-blosc2 not in numexpr
3166- global safe_blosc2_globals
3167- if len (safe_blosc2_globals ) == 0 :
3168- # First eval call, fill blosc2_safe_globals for ne_evaluate
3169- safe_blosc2_globals = {"blosc2" : blosc2 }
3170- # Add all first-level blosc2 functions
3171- safe_blosc2_globals .update (
3172- {
3173- name : getattr (blosc2 , name )
3174- for name in dir (blosc2 )
3175- if callable (getattr (blosc2 , name )) and not name .startswith ("_" )
3176- }
3177- )
3178-
3173+ def _compute_expr (self , item , kwargs ):
31793174 if any (method in self .expression for method in eager_funcs ):
31803175 # We have reductions in the expression (probably coming from a string lazyexpr)
31813176 # Also includes slice
0 commit comments