Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 60 additions & 148 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ use arrow_array::types::{
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
};
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
use arrow_buffer::{
ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, RunEndBuffer, ScalarBuffer, bit_util,
};
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayDataBuilder;
use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
use arrow_data::transform::MutableArrayData;
use arrow_schema::*;
Expand Down Expand Up @@ -579,6 +580,14 @@ fn filter_null_mask(
Some((null_count, nulls))
}

/// Filters `nulls` and reuses the computed `null_count` to avoid scanning the bitmap.
fn filter_nulls(nulls: Option<&NullBuffer>, predicate: &FilterPredicate) -> Option<NullBuffer> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making this a method on FilterPredicate? That would make it easier to find / reuse I think.

impl FilterPredicate { 
  fn filter_nulls(&self, nulls:  Option<&NullBuffer>) -> Option<NullBuffer> {
     ...
  }
}

let (null_count, nulls) = filter_null_mask(nulls, predicate)?;
let buffer = BooleanBuffer::new(nulls, 0, predicate.count);

Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we please add a safety comment here explaining why this is safe:

Suggested change
Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
// Safety: null_count return from filter_null_mas is correct
Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })

It might also be nice to add a debug assert here to verify in debug builds

debug_assert_eq!(null_count, nulls.num_zeros())

}

/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
let src = buffer.values();
Expand Down Expand Up @@ -624,18 +633,11 @@ fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {

/// `filter` implementation for boolean buffers
fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
let values = filter_bits(array.values(), predicate);

let mut builder = ArrayDataBuilder::new(DataType::Boolean)
.len(predicate.count)
.add_buffer(values);

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
}
let buffer = filter_bits(array.values(), predicate);
let values = BooleanBuffer::new(buffer, 0, predicate.count);
let nulls = filter_nulls(array.nulls(), predicate);

let data = unsafe { builder.build_unchecked() };
BooleanArray::from(data)
BooleanArray::new(values, nulls)
}

#[inline(never)]
Expand Down Expand Up @@ -681,18 +683,17 @@ fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -
where
T: ArrowPrimitiveType,
{
let values = array.values();
let buffer = filter_native(values, predicate);
let mut builder = ArrayDataBuilder::new(array.data_type().clone())
.len(predicate.count)
.add_buffer(buffer);

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
let buffer = filter_native(array.values(), predicate);
let values = ScalarBuffer::new(buffer, 0, predicate.count);
let nulls = filter_nulls(array.nulls(), predicate);
let filtered = PrimitiveArray::new(values, nulls);

// Avoid the compatibility check when the physical type already matches.
if array.data_type() == &T::DATA_TYPE {
filtered
} else {
filtered.with_data_type(array.data_type().clone())
}

let data = unsafe { builder.build_unchecked() };
PrimitiveArray::from(data)
}

/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
Expand Down Expand Up @@ -824,17 +825,10 @@ where
IterationStrategy::All | IterationStrategy::None => unreachable!(),
}

let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
.len(predicate.count)
.add_buffer(filter.dst_offsets.into())
.add_buffer(filter.dst_values.into());
let offsets = unsafe { OffsetBuffer::new_unchecked(filter.dst_offsets.into()) };
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it would also be nice to comment about why this is safe (what assumptions it relies on). However, I see the existing code doesn't have a safety comment

// Safety: offsets are correctly constructed

let nulls = filter_nulls(array.nulls(), predicate);

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
}

let data = unsafe { builder.build_unchecked() };
GenericByteArray::from(data)
unsafe { GenericByteArray::new_unchecked(offsets, filter.dst_values.into(), nulls) }
}

/// `filter` implementation for byte view arrays.
Expand All @@ -843,17 +837,11 @@ fn filter_byte_view<T: ByteViewType>(
predicate: &FilterPredicate,
) -> GenericByteViewArray<T> {
let new_view_buffer = filter_native(array.views(), predicate);
let views = ScalarBuffer::new(new_view_buffer, 0, predicate.count);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here (and other places) you can probably use the unchecked variants too to skip some checks, if we need to get additional speed (ScalarBuffer::new_unchecked)

However, given your PR removes an allocation (the buffers array) I suspect this is already going to be faster and avoiding unsafe is a nice bonus ❤️

let buffers = array.data_buffers().to_vec();
let nulls = filter_nulls(array.nulls(), predicate);

let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
.len(predicate.count)
.add_buffer(new_view_buffer)
.add_buffers(array.data_buffers().to_vec());

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
}

GenericByteViewArray::from(unsafe { builder.build_unchecked() })
unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) }
}

fn filter_fixed_size_binary(
Expand Down Expand Up @@ -902,16 +890,10 @@ fn filter_fixed_size_binary(
}
IterationStrategy::All | IterationStrategy::None => unreachable!(),
};
let mut builder = ArrayDataBuilder::new(array.data_type().clone())
.len(predicate.count)
.add_buffer(buffer.into());

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
}
let nulls = filter_nulls(array.nulls(), predicate);

let data = unsafe { builder.build_unchecked() };
FixedSizeBinaryArray::from(data)
FixedSizeBinaryArray::new(array.value_length(), buffer.into(), nulls)
}

/// `filter` implementation for dictionaries
Expand Down Expand Up @@ -992,24 +974,16 @@ fn filter_list_view<OffsetType: OffsetSizeTrait>(
let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);

// Filter the nulls
let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
let buffer = BooleanBuffer::new(nulls, 0, predicate.count);

Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
} else {
None
let field = match array.data_type() {
DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
_ => unreachable!(),
};
let offsets = ScalarBuffer::new(filtered_offsets, 0, predicate.count);
let sizes = ScalarBuffer::new(filtered_sizes, 0, predicate.count);
let values = array.values().clone();
let nulls = filter_nulls(array.nulls(), predicate);

let list_data = ArrayDataBuilder::new(array.data_type().clone())
.nulls(nulls)
.buffers(vec![filtered_offsets, filtered_sizes])
.child_data(vec![array.values().to_data()])
.len(predicate.count);

let list_data = unsafe { list_data.build_unchecked() };

GenericListViewArray::from(list_data)
unsafe { GenericListViewArray::new_unchecked(field, offsets, sizes, values, nulls) }
}

#[cfg(test)]
Expand All @@ -1018,7 +992,6 @@ mod tests {
use arrow_array::builder::*;
use arrow_array::cast::as_run_array;
use arrow_array::types::*;
use arrow_data::ArrayData;
use rand::distr::uniform::{UniformSampler, UniformUsize};
use rand::distr::{Alphanumeric, StandardUniform};
use rand::prelude::*;
Expand Down Expand Up @@ -1494,49 +1467,22 @@ mod tests {

#[test]
fn test_filter_list_array() {
let value_data = ArrayData::builder(DataType::Int32)
.len(8)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);

let list_data_type =
DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
let list_data = ArrayData::builder(list_data_type)
.len(4)
.add_buffer(value_offsets)
.add_child_data(value_data)
.null_bit_buffer(Some(Buffer::from([0b00000111])))
.build()
.unwrap();

let field = Arc::new(Field::new_list_field(DataType::Int32, false));
let offsets = OffsetBuffer::new(vec![0i64, 3, 6, 8, 8].into());
let value_array = Arc::new(Int32Array::from_iter_values(0..8));
let nulls = Some(NullBuffer::from(vec![true, true, true, false]));
// a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
let a = LargeListArray::from(list_data);
let a = LargeListArray::new(field.clone(), offsets, value_array, nulls);
let b = BooleanArray::from(vec![false, true, false, true]);
let result = filter(&a, &b).unwrap();

// expected: [[3, 4, 5], null]
let value_data = ArrayData::builder(DataType::Int32)
.len(3)
.add_buffer(Buffer::from_slice_ref([3, 4, 5]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);

let list_data_type =
DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
let expected = ArrayData::builder(list_data_type)
.len(2)
.add_buffer(value_offsets)
.add_child_data(value_data)
.null_bit_buffer(Some(Buffer::from([0b00000001])))
.build()
.unwrap();
let offsets = OffsetBuffer::new(vec![0i64, 3, 3].into());
let value_array = Arc::new(Int32Array::from_iter_values([3, 4, 5]));
let nulls = Some(NullBuffer::from(vec![true, false]));
let expected: ArrayRef = Arc::new(LargeListArray::new(field, offsets, value_array, nulls));

assert_eq!(&make_array(expected), &result);
assert_eq!(&expected, &result);
}

fn test_case_filter_list_view<T: OffsetSizeTrait>() {
Expand Down Expand Up @@ -1719,14 +1665,7 @@ mod tests {

let truncated_length = mask_len - offset - truncate;

let data = ArrayDataBuilder::new(DataType::Boolean)
.len(truncated_length)
.offset(offset)
.add_buffer(buffer)
.build()
.unwrap();

let filter = BooleanArray::from(data);
let filter = BooleanArray::new(BooleanBuffer::new(buffer, offset, truncated_length), None);

let slice_bits: Vec<_> = SlicesIterator::new(&filter)
.flat_map(|(start, end)| start..end)
Expand Down Expand Up @@ -1949,18 +1888,9 @@ mod tests {

#[test]
fn test_filter_fixed_size_list_arrays() {
let value_data = ArrayData::builder(DataType::Int32)
.len(9)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
.build()
.unwrap();
let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
let list_data = ArrayData::builder(list_data_type)
.len(3)
.add_child_data(value_data)
.build()
.unwrap();
let array = FixedSizeListArray::from(list_data);
let field = Arc::new(Field::new_list_field(DataType::Int32, false));
let value_array = Arc::new(Int32Array::from_iter_values(0..9));
let array = FixedSizeListArray::new(field, 3, value_array, None);

let filter_array = BooleanArray::from(vec![true, false, false]);

Expand Down Expand Up @@ -1996,28 +1926,10 @@ mod tests {

#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
let value_data = ArrayData::builder(DataType::Int32)
.len(10)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
.build()
.unwrap();

// Set null buts for the nested array:
// [[0, 1], null, null, [6, 7], [8, 9]]
// 01011001 00000001
let mut null_bits: [u8; 1] = [0; 1];
bit_util::set_bit(&mut null_bits, 0);
bit_util::set_bit(&mut null_bits, 3);
bit_util::set_bit(&mut null_bits, 4);

let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
let list_data = ArrayData::builder(list_data_type)
.len(5)
.add_child_data(value_data)
.null_bit_buffer(Some(Buffer::from(null_bits)))
.build()
.unwrap();
let array = FixedSizeListArray::from(list_data);
let field = Arc::new(Field::new_list_field(DataType::Int32, false));
let value_array = Arc::new(Int32Array::from_iter_values(0..10));
let nulls = Some(NullBuffer::from(vec![true, false, false, true, true]));
let array = FixedSizeListArray::new(field, 2, value_array, nulls);

let filter_array = BooleanArray::from(vec![true, true, false, true, false]);

Expand Down
Loading