-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[arrow-select] Replace ArrayData with direct Array construction in filter kernels
#9986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,9 +26,10 @@ use arrow_array::types::{ | |||||||
| ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType, | ||||||||
| }; | ||||||||
| use arrow_array::*; | ||||||||
| use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util}; | ||||||||
| use arrow_buffer::{ | ||||||||
| ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, RunEndBuffer, ScalarBuffer, bit_util, | ||||||||
| }; | ||||||||
| use arrow_buffer::{Buffer, MutableBuffer}; | ||||||||
| use arrow_data::ArrayDataBuilder; | ||||||||
| use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; | ||||||||
| use arrow_data::transform::MutableArrayData; | ||||||||
| use arrow_schema::*; | ||||||||
|
|
@@ -579,6 +580,14 @@ fn filter_null_mask( | |||||||
| Some((null_count, nulls)) | ||||||||
| } | ||||||||
|
|
||||||||
| /// Filters `nulls` and reuses the computed `null_count` to avoid scanning the bitmap. | ||||||||
| fn filter_nulls(nulls: Option<&NullBuffer>, predicate: &FilterPredicate) -> Option<NullBuffer> { | ||||||||
| let (null_count, nulls) = filter_null_mask(nulls, predicate)?; | ||||||||
| let buffer = BooleanBuffer::new(nulls, 0, predicate.count); | ||||||||
|
|
||||||||
| Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we please add a safety comment here explaining why this is safe:
Suggested change
It might also be nice to add a debug assertion here to verify the null count in debug builds: `debug_assert_eq!(null_count, nulls.num_zeros())` |
||||||||
| } | ||||||||
|
|
||||||||
| /// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset` | ||||||||
| fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer { | ||||||||
| let src = buffer.values(); | ||||||||
|
|
@@ -624,18 +633,11 @@ fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer { | |||||||
|
|
||||||||
| /// `filter` implementation for boolean buffers | ||||||||
| fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray { | ||||||||
| let values = filter_bits(array.values(), predicate); | ||||||||
|
|
||||||||
| let mut builder = ArrayDataBuilder::new(DataType::Boolean) | ||||||||
| .len(predicate.count) | ||||||||
| .add_buffer(values); | ||||||||
|
|
||||||||
| if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); | ||||||||
| } | ||||||||
| let buffer = filter_bits(array.values(), predicate); | ||||||||
| let values = BooleanBuffer::new(buffer, 0, predicate.count); | ||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
|
|
||||||||
| let data = unsafe { builder.build_unchecked() }; | ||||||||
| BooleanArray::from(data) | ||||||||
| BooleanArray::new(values, nulls) | ||||||||
| } | ||||||||
|
|
||||||||
| #[inline(never)] | ||||||||
|
|
@@ -681,18 +683,17 @@ fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) - | |||||||
| where | ||||||||
| T: ArrowPrimitiveType, | ||||||||
| { | ||||||||
| let values = array.values(); | ||||||||
| let buffer = filter_native(values, predicate); | ||||||||
| let mut builder = ArrayDataBuilder::new(array.data_type().clone()) | ||||||||
| .len(predicate.count) | ||||||||
| .add_buffer(buffer); | ||||||||
|
|
||||||||
| if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); | ||||||||
| let buffer = filter_native(array.values(), predicate); | ||||||||
| let values = ScalarBuffer::new(buffer, 0, predicate.count); | ||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
| let filtered = PrimitiveArray::new(values, nulls); | ||||||||
|
|
||||||||
| // Avoid the compatibility check when the physical type already matches. | ||||||||
| if array.data_type() == &T::DATA_TYPE { | ||||||||
| filtered | ||||||||
| } else { | ||||||||
| filtered.with_data_type(array.data_type().clone()) | ||||||||
| } | ||||||||
|
|
||||||||
| let data = unsafe { builder.build_unchecked() }; | ||||||||
| PrimitiveArray::from(data) | ||||||||
| } | ||||||||
|
|
||||||||
| /// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be | ||||||||
|
|
@@ -824,17 +825,10 @@ where | |||||||
| IterationStrategy::All | IterationStrategy::None => unreachable!(), | ||||||||
| } | ||||||||
|
|
||||||||
| let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) | ||||||||
| .len(predicate.count) | ||||||||
| .add_buffer(filter.dst_offsets.into()) | ||||||||
| .add_buffer(filter.dst_values.into()); | ||||||||
| let offsets = unsafe { OffsetBuffer::new_unchecked(filter.dst_offsets.into()) }; | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here it would also be nice to add a comment about why this is safe (what assumptions it relies on). However, I see the existing code doesn't have a safety comment. `// Safety: offsets are correctly constructed` |
||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
|
|
||||||||
| if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); | ||||||||
| } | ||||||||
|
|
||||||||
| let data = unsafe { builder.build_unchecked() }; | ||||||||
| GenericByteArray::from(data) | ||||||||
| unsafe { GenericByteArray::new_unchecked(offsets, filter.dst_values.into(), nulls) } | ||||||||
| } | ||||||||
|
|
||||||||
| /// `filter` implementation for byte view arrays. | ||||||||
|
|
@@ -843,17 +837,11 @@ fn filter_byte_view<T: ByteViewType>( | |||||||
| predicate: &FilterPredicate, | ||||||||
| ) -> GenericByteViewArray<T> { | ||||||||
| let new_view_buffer = filter_native(array.views(), predicate); | ||||||||
| let views = ScalarBuffer::new(new_view_buffer, 0, predicate.count); | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here (and in other places) you can probably use the unchecked variants too, to skip some checks if we need additional speed. However, given that your PR removes an allocation (the buffers array), I suspect this is already going to be faster, and avoiding unsafe is a nice bonus ❤️ |
||||||||
| let buffers = array.data_buffers().to_vec(); | ||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
|
|
||||||||
| let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) | ||||||||
| .len(predicate.count) | ||||||||
| .add_buffer(new_view_buffer) | ||||||||
| .add_buffers(array.data_buffers().to_vec()); | ||||||||
|
|
||||||||
| if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); | ||||||||
| } | ||||||||
|
|
||||||||
| GenericByteViewArray::from(unsafe { builder.build_unchecked() }) | ||||||||
| unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) } | ||||||||
| } | ||||||||
|
|
||||||||
| fn filter_fixed_size_binary( | ||||||||
|
|
@@ -902,16 +890,10 @@ fn filter_fixed_size_binary( | |||||||
| } | ||||||||
| IterationStrategy::All | IterationStrategy::None => unreachable!(), | ||||||||
| }; | ||||||||
| let mut builder = ArrayDataBuilder::new(array.data_type().clone()) | ||||||||
| .len(predicate.count) | ||||||||
| .add_buffer(buffer.into()); | ||||||||
|
|
||||||||
| if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); | ||||||||
| } | ||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
|
|
||||||||
| let data = unsafe { builder.build_unchecked() }; | ||||||||
| FixedSizeBinaryArray::from(data) | ||||||||
| FixedSizeBinaryArray::new(array.value_length(), buffer.into(), nulls) | ||||||||
| } | ||||||||
|
|
||||||||
| /// `filter` implementation for dictionaries | ||||||||
|
|
@@ -992,24 +974,16 @@ fn filter_list_view<OffsetType: OffsetSizeTrait>( | |||||||
| let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate); | ||||||||
| let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate); | ||||||||
|
|
||||||||
| // Filter the nulls | ||||||||
| let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||||||||
| let buffer = BooleanBuffer::new(nulls, 0, predicate.count); | ||||||||
|
|
||||||||
| Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) | ||||||||
| } else { | ||||||||
| None | ||||||||
| let field = match array.data_type() { | ||||||||
| DataType::ListView(field) | DataType::LargeListView(field) => field.clone(), | ||||||||
| _ => unreachable!(), | ||||||||
| }; | ||||||||
| let offsets = ScalarBuffer::new(filtered_offsets, 0, predicate.count); | ||||||||
| let sizes = ScalarBuffer::new(filtered_sizes, 0, predicate.count); | ||||||||
| let values = array.values().clone(); | ||||||||
| let nulls = filter_nulls(array.nulls(), predicate); | ||||||||
|
|
||||||||
| let list_data = ArrayDataBuilder::new(array.data_type().clone()) | ||||||||
| .nulls(nulls) | ||||||||
| .buffers(vec![filtered_offsets, filtered_sizes]) | ||||||||
| .child_data(vec![array.values().to_data()]) | ||||||||
| .len(predicate.count); | ||||||||
|
|
||||||||
| let list_data = unsafe { list_data.build_unchecked() }; | ||||||||
|
|
||||||||
| GenericListViewArray::from(list_data) | ||||||||
| unsafe { GenericListViewArray::new_unchecked(field, offsets, sizes, values, nulls) } | ||||||||
| } | ||||||||
|
|
||||||||
| #[cfg(test)] | ||||||||
|
|
@@ -1018,7 +992,6 @@ mod tests { | |||||||
| use arrow_array::builder::*; | ||||||||
| use arrow_array::cast::as_run_array; | ||||||||
| use arrow_array::types::*; | ||||||||
| use arrow_data::ArrayData; | ||||||||
| use rand::distr::uniform::{UniformSampler, UniformUsize}; | ||||||||
| use rand::distr::{Alphanumeric, StandardUniform}; | ||||||||
| use rand::prelude::*; | ||||||||
|
|
@@ -1494,49 +1467,22 @@ mod tests { | |||||||
|
|
||||||||
| #[test] | ||||||||
| fn test_filter_list_array() { | ||||||||
| let value_data = ArrayData::builder(DataType::Int32) | ||||||||
| .len(8) | ||||||||
| .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
|
|
||||||||
| let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]); | ||||||||
|
|
||||||||
| let list_data_type = | ||||||||
| DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); | ||||||||
| let list_data = ArrayData::builder(list_data_type) | ||||||||
| .len(4) | ||||||||
| .add_buffer(value_offsets) | ||||||||
| .add_child_data(value_data) | ||||||||
| .null_bit_buffer(Some(Buffer::from([0b00000111]))) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
|
|
||||||||
| let field = Arc::new(Field::new_list_field(DataType::Int32, false)); | ||||||||
| let offsets = OffsetBuffer::new(vec![0i64, 3, 6, 8, 8].into()); | ||||||||
| let value_array = Arc::new(Int32Array::from_iter_values(0..8)); | ||||||||
| let nulls = Some(NullBuffer::from(vec![true, true, true, false])); | ||||||||
| // a = [[0, 1, 2], [3, 4, 5], [6, 7], null] | ||||||||
| let a = LargeListArray::from(list_data); | ||||||||
| let a = LargeListArray::new(field.clone(), offsets, value_array, nulls); | ||||||||
| let b = BooleanArray::from(vec![false, true, false, true]); | ||||||||
| let result = filter(&a, &b).unwrap(); | ||||||||
|
|
||||||||
| // expected: [[3, 4, 5], null] | ||||||||
| let value_data = ArrayData::builder(DataType::Int32) | ||||||||
| .len(3) | ||||||||
| .add_buffer(Buffer::from_slice_ref([3, 4, 5])) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
|
|
||||||||
| let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]); | ||||||||
|
|
||||||||
| let list_data_type = | ||||||||
| DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); | ||||||||
| let expected = ArrayData::builder(list_data_type) | ||||||||
| .len(2) | ||||||||
| .add_buffer(value_offsets) | ||||||||
| .add_child_data(value_data) | ||||||||
| .null_bit_buffer(Some(Buffer::from([0b00000001]))) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
| let offsets = OffsetBuffer::new(vec![0i64, 3, 3].into()); | ||||||||
| let value_array = Arc::new(Int32Array::from_iter_values([3, 4, 5])); | ||||||||
| let nulls = Some(NullBuffer::from(vec![true, false])); | ||||||||
| let expected: ArrayRef = Arc::new(LargeListArray::new(field, offsets, value_array, nulls)); | ||||||||
|
|
||||||||
| assert_eq!(&make_array(expected), &result); | ||||||||
| assert_eq!(&expected, &result); | ||||||||
| } | ||||||||
|
|
||||||||
| fn test_case_filter_list_view<T: OffsetSizeTrait>() { | ||||||||
|
|
@@ -1719,14 +1665,7 @@ mod tests { | |||||||
|
|
||||||||
| let truncated_length = mask_len - offset - truncate; | ||||||||
|
|
||||||||
| let data = ArrayDataBuilder::new(DataType::Boolean) | ||||||||
| .len(truncated_length) | ||||||||
| .offset(offset) | ||||||||
| .add_buffer(buffer) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
|
|
||||||||
| let filter = BooleanArray::from(data); | ||||||||
| let filter = BooleanArray::new(BooleanBuffer::new(buffer, offset, truncated_length), None); | ||||||||
|
|
||||||||
| let slice_bits: Vec<_> = SlicesIterator::new(&filter) | ||||||||
| .flat_map(|(start, end)| start..end) | ||||||||
|
|
@@ -1949,18 +1888,9 @@ mod tests { | |||||||
|
|
||||||||
| #[test] | ||||||||
| fn test_filter_fixed_size_list_arrays() { | ||||||||
| let value_data = ArrayData::builder(DataType::Int32) | ||||||||
| .len(9) | ||||||||
| .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
| let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false); | ||||||||
| let list_data = ArrayData::builder(list_data_type) | ||||||||
| .len(3) | ||||||||
| .add_child_data(value_data) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
| let array = FixedSizeListArray::from(list_data); | ||||||||
| let field = Arc::new(Field::new_list_field(DataType::Int32, false)); | ||||||||
| let value_array = Arc::new(Int32Array::from_iter_values(0..9)); | ||||||||
| let array = FixedSizeListArray::new(field, 3, value_array, None); | ||||||||
|
|
||||||||
| let filter_array = BooleanArray::from(vec![true, false, false]); | ||||||||
|
|
||||||||
|
|
@@ -1996,28 +1926,10 @@ mod tests { | |||||||
|
|
||||||||
| #[test] | ||||||||
| fn test_filter_fixed_size_list_arrays_with_null() { | ||||||||
| let value_data = ArrayData::builder(DataType::Int32) | ||||||||
| .len(10) | ||||||||
| .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
|
|
||||||||
| // Set null buts for the nested array: | ||||||||
| // [[0, 1], null, null, [6, 7], [8, 9]] | ||||||||
| // 01011001 00000001 | ||||||||
| let mut null_bits: [u8; 1] = [0; 1]; | ||||||||
| bit_util::set_bit(&mut null_bits, 0); | ||||||||
| bit_util::set_bit(&mut null_bits, 3); | ||||||||
| bit_util::set_bit(&mut null_bits, 4); | ||||||||
|
|
||||||||
| let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false); | ||||||||
| let list_data = ArrayData::builder(list_data_type) | ||||||||
| .len(5) | ||||||||
| .add_child_data(value_data) | ||||||||
| .null_bit_buffer(Some(Buffer::from(null_bits))) | ||||||||
| .build() | ||||||||
| .unwrap(); | ||||||||
| let array = FixedSizeListArray::from(list_data); | ||||||||
| let field = Arc::new(Field::new_list_field(DataType::Int32, false)); | ||||||||
| let value_array = Arc::new(Int32Array::from_iter_values(0..10)); | ||||||||
| let nulls = Some(NullBuffer::from(vec![true, false, false, true, true])); | ||||||||
| let array = FixedSizeListArray::new(field, 2, value_array, nulls); | ||||||||
|
|
||||||||
| let filter_array = BooleanArray::from(vec![true, true, false, true, false]); | ||||||||
|
|
||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about making this a method on `FilterPredicate`? That would make it easier to find and reuse, I think.