Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dont use take with arrow #2336

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion packages/vaex-core/vaex/array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ def filter(ar, boolean_mask):


def take(ar, indices):
    """Return the elements of ``ar`` found at the positions in ``indices``.

    Dispatches on the container type of ``ar``:

    * ``pa.ChunkedArray``: each index is sliced out individually and the
      one-element pieces are concatenated into a single ``pa.Array``;
      ``ChunkedArray.take`` is deliberately avoided (see the linked issue).
    * ``pa.Array``: delegates to ``ar.take`` after converting ``indices``
      with ``to_arrow``.
    * anything else: plain indexing, ``ar[indices]``.

    :param ar: array-like to select from.
    :param indices: sequence of integer positions into ``ar``.
    :return: the selected elements; the container type depends on the branch
        taken above.
    """
    if isinstance(ar, pa.ChunkedArray):
        # Don't use .take in arrow for chunked arrays
        # https://issues.apache.org/jira/browse/ARROW-9773
        if len(indices) == 0:
            # pa.concat_arrays needs at least one array, so build the empty
            # result directly while keeping the original element type.
            return pa.array([], type=ar.type)
        # slice is zero-copy
        return pa.concat_arrays(
            [ar.slice(i, 1).combine_chunks() for i in indices]
        )
    elif isinstance(ar, pa.lib.Array):
        return ar.take(to_arrow(indices))
    else:
        return ar[indices]


def slice(ar, offset, length=None):
Expand Down
11 changes: 4 additions & 7 deletions packages/vaex-core/vaex/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import vaex
import vaex.utils
import vaex.cache
from .array_types import supported_array_types, supported_arrow_array_types, string_types, is_string_type
from .array_types import supported_array_types, supported_arrow_array_types, string_types, is_string_type, take


if vaex.utils.has_c_extension:
Expand Down Expand Up @@ -382,10 +382,7 @@ def __getitem__(self, slice):
# arrow and numpy do not like the negative indices, so we set them to 0
take_indices = indices.copy()
take_indices[mask] = 0
if isinstance(ar_unfiltered, supported_arrow_array_types):
ar = ar_unfiltered.take(vaex.array_types.to_arrow(take_indices))
else:
ar = ar_unfiltered[take_indices]
ar = take(ar_unfiltered, take_indices)
assert not np.ma.isMaskedArray(indices)
if self.masked:
# TODO: we probably want to keep this as arrow array if it originally was
Expand Down Expand Up @@ -594,7 +591,7 @@ def _is_stringy(x):

def _to_string_sequence(x, force=True):
if isinstance(x, pa.DictionaryArray):
x = x.dictionary.take(x.indices) # equivalent to PyArrow 5.0.0's dictionary_decode() but backwards compatible
x = take(x.dictionary, x.indices) # equivalent to PyArrow 5.0.0's dictionary_decode() but backwards compatible
if isinstance(x, pa.ChunkedArray):
# turn into pa.Array, TODO: do we want this, this may result in a big mem copy
table = pa.Table.from_arrays([x], ["single"])
Expand Down Expand Up @@ -825,4 +822,4 @@ def get_mask(self):
return self.string_sequence.mask()

def astype(self, type):
return self.to_numpy().astype(type)
return self.to_numpy().astype(type)
4 changes: 2 additions & 2 deletions packages/vaex-core/vaex/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def reduce(self, others: List["TaskPartValueCounts"]):
deletes.append(counter.nan_index)
if vaex.array_types.is_arrow_array(keys):
indices = np.delete(np.arange(len(keys)), deletes)
keys = keys.take(indices)
keys = vaex.array_types.take(keys, indices)
else:
keys = np.delete(keys, deletes)
if not self.dropmissing and counter.has_null:
Expand All @@ -264,7 +264,7 @@ def reduce(self, others: List["TaskPartValueCounts"]):
if not self.ascending:
order = order[::-1]
counts = counts[order]
keys = keys.take(order)
keys = vaex.array_types.take(keys, order)

keys = keys.tolist()
if None in keys:
Expand Down
4 changes: 2 additions & 2 deletions packages/vaex-core/vaex/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def unique(self, expression, return_inverse=False, dropna=False, dropnan=False,
keys = pa.array(self.category_labels(expression))
@delayed
def encode(codes):
used_keys = keys.take(codes)
used_keys = vaex.array_types.take(keys, codes)
return vaex.array_types.convert(used_keys, array_type)
codes = self[expression].index_values().unique(delay=True)
return self._delay(delay, encode(codes))
Expand Down Expand Up @@ -659,7 +659,7 @@ def reduce(a, b):
if isinstance(keys, (vaex.strings.StringList32, vaex.strings.StringList64)):
keys = vaex.strings.to_arrow(keys)
indices = np.delete(np.arange(len(keys)), deletes)
keys = keys.take(indices)
keys = vaex.array_types.take(keys, indices)
else:
keys = np.delete(keys, deletes)
if not dropmissing and hash_map_unique.has_null:
Expand Down
2 changes: 1 addition & 1 deletion packages/vaex-core/vaex/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,7 +2493,7 @@ def _map(ar, value_to_index, choices, default_value=None, use_missing=False, axi

ar = vaex.array_types.to_numpy(ar)
indices = value_to_index.map(ar) + 1
values = choices.take(indices)
values = vaex.array_types.take(choices, indices)
if np.ma.isMaskedArray(ar):
mask = np.ma.getmaskarray(ar).copy()
# also mask out the missing (which had -1 and was moved to 0)
Expand Down
10 changes: 5 additions & 5 deletions packages/vaex-core/vaex/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def process(hashmap_unique: vaex.hash.HashMapUnique):
indices = pa.compute.sort_indices(self.bin_values, sort_keys=[("x", "ascending" if ascending else "descending")])
self.sort_indices = vaex.array_types.to_numpy(indices)
# the bin_values will still be pre sorted, maybe that is confusing (implementation detail)
self.bin_values = pa.compute.take(self.bin_values, self.sort_indices)
self.bin_values = vaex.array_types.take(self.bin_values, self.sort_indices)
else:
self.sort_indices = None
self.hashmap_unique = hashmap_unique
Expand Down Expand Up @@ -377,10 +377,10 @@ def compress(ar):
if dtype.is_struct:
# collapse parent struct into our flat struct
for field, ar in zip(parent.bin_values.type, parent.bin_values.flatten()):
bin_values[field.name] = ar.take(indices)
bin_values[field.name] = vaex.array_types.take(ar, indices)
# bin_values[field.name] = pa.DictionaryArray.from_arrays(indices, ar)
else:
bin_values[parent.label] = parent.bin_values.take(indices)
bin_values[parent.label] = vaex.array_types.take(parent.bin_values, indices)
# bin_values[parent.label] = pa.DictionaryArray.from_arrays(indices, parent.bin_values)
logger.info(f"extracing labels of parent groupers done")
return pa.StructArray.from_arrays(bin_values.values(), bin_values.keys())
Expand Down Expand Up @@ -418,7 +418,7 @@ def __init__(self, expression, df=None, sort=False, ascending=True, row_limit=No
if self.sort:
# not pre-sorting is faster
sort_indices = pa.compute.sort_indices(self.bin_values, sort_keys=[("x", "ascending" if ascending else "descending")])
self.bin_values = pa.compute.take(self.bin_values, sort_indices)
self.bin_values = vaex.array_types.take(self.bin_values, sort_indices)
if self.pre_sort:
# we will map from int to int
sort_indices = vaex.array_types.to_numpy(sort_indices)
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(self, expression, values, keep_other=True, other_value=None, sort=F
values = pa.concat_arrays(values.chunks)
if sort:
indices = pa.compute.sort_indices(values, sort_keys=[("x", "ascending" if ascending else "descending")])
values = pa.compute.take(values, indices)
values = vaex.array_types.take(values, indices)

if self.keep_other:
self.bin_values = pa.array(vaex.array_types.tolist(values) + [other_value])
Expand Down
2 changes: 1 addition & 1 deletion packages/vaex-core/vaex/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def sorted(self, keys=None, ascending=True, indices=None, return_keys=False):
indices = pa.compute.sort_indices(keys, sort_keys=[('x', "ascending" if ascending else "descending")]) if indices is None else indices
# arrow sorts with null last
null_index = -1 if not self.has_null else len(keys)-1
keys = pa.compute.take(keys, indices)
keys = vaex.array_types.take(keys, indices)
fingerprint = self._internal.fingerprint + "-sorted"
if self.dtype_item.is_string:
# TODO: supported 32 bit in hashmap
Expand Down