diff --git a/docs/src/guide/migration.md b/docs/src/guide/migration.md
index 92c0612930..9b7471ed07 100644
--- a/docs/src/guide/migration.md
+++ b/docs/src/guide/migration.md
@@ -6,6 +6,13 @@ stable and breaking changes should generally be communicated (via warnings) for
 give users a chance to migrate. This page documents the breaking changes between releases
 and gives advice on how to migrate.
 
+## 1.0.0
+
+* The `SearchResult` returned by scalar indices must now include information about null values.
+  Instead of containing a `RowIdTreeMap`, it now contains a `NullableRowIdSet`. Rows for which
+  an indexed expression evaluates to null must be reported in the result's null set. This
+  ensures that `NOT` can be applied to index search results correctly.
+
 ## 0.39
 
 * The `lance` crate no longer re-exports utilities from `lance-arrow` such as `RecordBatchExt` or `SchemaExt`. In the
diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py
index 32c787ad9f..126858a739 100644
--- a/python/python/tests/test_scalar_index.py
+++ b/python/python/tests/test_scalar_index.py
@@ -1798,13 +1798,14 @@ def test_json_index():
     )
 
 
-def test_null_handling(tmp_path: Path):
+def test_null_handling():
     tbl = pa.table(
         {
             "x": [1, 2, None, 3],
+            "y": ["a", "b", "c", None],
         }
     )
-    dataset = lance.write_dataset(tbl, tmp_path / "dataset")
+    dataset = lance.write_dataset(tbl, "memory://test")
 
     def check():
         assert dataset.to_table(filter="x IS NULL").num_rows == 1
@@ -1813,11 +1814,19 @@ def check():
         assert dataset.to_table(filter="x < 5").num_rows == 3
         assert dataset.to_table(filter="x IN (1, 2)").num_rows == 2
         assert dataset.to_table(filter="x IN (1, 2, NULL)").num_rows == 2
+        assert dataset.to_table(filter="x > 0 OR (y != 'a')").num_rows == 4
+        assert dataset.to_table(filter="x > 0 AND (y != 'a')").num_rows == 1
+        assert dataset.to_table(filter="y != 'a'").num_rows == 2
+        # NOT should exclude nulls (issue #4756)
+        assert dataset.to_table(filter="NOT (x < 2)").num_rows == 2
+        assert dataset.to_table(filter="NOT (x IN (1, 2))").num_rows == 1
+        # Double NOT
+        assert dataset.to_table(filter="NOT (NOT (x < 2))").num_rows == 1
 
     check()
 
     dataset.create_scalar_index("x", index_type="BITMAP")
     check()
 
-    dataset.create_scalar_index("x", index_type="BTREE")
+    dataset.create_scalar_index("y", index_type="BTREE")
     check()
diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs
index 1a06a3fc11..e0abe5c36d 100644
--- a/rust/lance-core/src/utils/mask.rs
+++ b/rust/lance-core/src/utils/mask.rs
@@ -3,8 +3,7 @@
 use std::collections::HashSet;
 use std::io::Write;
-use std::iter;
-use std::ops::{Range, RangeBounds};
+use std::ops::{Range, RangeBounds, RangeInclusive};
 use std::{collections::BTreeMap, io::Read};
 
 use arrow_array::{Array, BinaryArray, GenericBinaryArray};
@@ -17,20 +16,22 @@ use crate::Result;
 
 use super::address::RowAddress;
 
-/// A row id mask to select or deselect particular row ids
-///
-/// If both the allow_list and the block_list are Some then the only selected
-/// row ids are those that are in the allow_list but not in the block_list
-/// (the block_list takes precedence)
-///
-/// If both the allow_list and the block_list are None (the default) then
-/// all row ids are selected
-#[derive(Clone, Debug, Default, DeepSizeOf)]
-pub struct RowIdMask {
-    /// If Some then only these row ids are selected
-    pub allow_list: Option<RowIdTreeMap>,
-    /// If Some then these row ids are not selected.
-    pub block_list: Option<RowIdTreeMap>,
+mod nullable;
+
+pub use nullable::{NullableRowIdMask, NullableRowIdSet};
+
+/// A mask that selects or deselects rows based on an allow-list or block-list.
+#[derive(Clone, Debug, DeepSizeOf)]
+pub enum RowIdMask {
+    AllowList(RowIdTreeMap),
+    BlockList(RowIdTreeMap),
+}
+
+impl Default for RowIdMask {
+    fn default() -> Self {
+        // Empty block list means all rows are allowed
+        Self::BlockList(RowIdTreeMap::new())
+    }
 }
 
 impl RowIdMask {
@@ -41,124 +42,68 @@
     // Create a mask that doesn't allow anything
     pub fn allow_nothing() -> Self {
-        Self {
-            allow_list: Some(RowIdTreeMap::new()),
-            block_list: None,
-        }
+        Self::AllowList(RowIdTreeMap::new())
     }
 
     // Create a mask from an allow list
     pub fn from_allowed(allow_list: RowIdTreeMap) -> Self {
-        Self {
-            allow_list: Some(allow_list),
-            block_list: None,
-        }
+        Self::AllowList(allow_list)
     }
 
     // Create a mask from a block list
     pub fn from_block(block_list: RowIdTreeMap) -> Self {
-        Self {
-            allow_list: None,
-            block_list: Some(block_list),
+        Self::BlockList(block_list)
+    }
+
+    pub fn block_list(&self) -> Option<&RowIdTreeMap> {
+        match self {
+            Self::BlockList(block_list) => Some(block_list),
+            _ => None,
         }
     }
 
-    // If there is both a block list and an allow list then collapse into just an allow list
-    pub fn normalize(self) -> Self {
-        if let Self {
-            allow_list: Some(mut allow_list),
-            block_list: Some(block_list),
-        } = self
-        {
-            allow_list -= &block_list;
-            Self {
-                allow_list: Some(allow_list),
-                block_list: None,
-            }
-        } else {
-            self
+    pub fn allow_list(&self) -> Option<&RowIdTreeMap> {
+        match self {
+            Self::AllowList(allow_list) => Some(allow_list),
+            _ => None,
        }
     }
 
     /// True if the row_id is selected by the mask, false otherwise
     pub fn selected(&self, row_id: u64) -> bool {
-        match (&self.allow_list, &self.block_list) {
-            (None, None) => true,
-            (Some(allow_list), None) => allow_list.contains(row_id),
-            (None, Some(block_list)) => !block_list.contains(row_id),
-            (Some(allow_list), Some(block_list)) => {
-                allow_list.contains(row_id) && !block_list.contains(row_id)
-            }
+        match self {
+            Self::AllowList(allow_list) => allow_list.contains(row_id),
+            Self::BlockList(block_list) => !block_list.contains(row_id),
         }
     }
 
     /// Return the indices of the input row ids that were valid
     pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
-        let enumerated_ids = row_ids.enumerate();
-        match (&self.block_list, &self.allow_list) {
-            (Some(block_list), Some(allow_list)) => {
-                // Only take rows that are both in the allow list and not in the block list
-                enumerated_ids
-                    .filter(|(_, row_id)| {
-                        !block_list.contains(**row_id) && allow_list.contains(**row_id)
-                    })
-                    .map(|(idx, _)| idx as u64)
-                    .collect()
-            }
-            (Some(block_list), None) => {
-                // Take rows that are not in the block list
-                enumerated_ids
-                    .filter(|(_, row_id)| !block_list.contains(**row_id))
-                    .map(|(idx, _)| idx as u64)
-                    .collect()
-            }
-            (None, Some(allow_list)) => {
-                // Take rows that are in the allow list
-                enumerated_ids
-                    .filter(|(_, row_id)| allow_list.contains(**row_id))
-                    .map(|(idx, _)| idx as u64)
-                    .collect()
-            }
-            (None, None) => {
-                // We should not encounter this case because callers should
-                // check is_empty first.
-                panic!("selected_indices called but prefilter has nothing to filter with")
-            }
-        }
+        row_ids
+            .enumerate()
+            .filter_map(|(idx, row_id)| {
+                if self.selected(*row_id) {
+                    Some(idx as u64)
+                } else {
+                    None
+                }
+            })
+            .collect()
     }
 
     /// Also block the given ids
     pub fn also_block(self, block_list: RowIdTreeMap) -> Self {
-        if block_list.is_empty() {
-            return self;
-        }
-        if let Some(existing) = self.block_list {
-            Self {
-                block_list: Some(existing | block_list),
-                allow_list: self.allow_list,
-            }
-        } else {
-            Self {
-                block_list: Some(block_list),
-                allow_list: self.allow_list,
-            }
+        match self {
+            Self::AllowList(allow_list) => Self::AllowList(allow_list - block_list),
+            Self::BlockList(existing) => Self::BlockList(existing | block_list),
         }
     }
 
     /// Also allow the given ids
     pub fn also_allow(self, allow_list: RowIdTreeMap) -> Self {
-        if let Some(existing) = self.allow_list {
-            Self {
-                block_list: self.block_list,
-                allow_list: Some(existing | allow_list),
-            }
-        } else {
-            Self {
-                block_list: self.block_list,
-                // allow_list = None means "all rows allowed" and so allowing
-                // more rows is meaningless
-                allow_list: None,
-            }
+        match self {
+            Self::AllowList(existing) => Self::AllowList(existing | allow_list),
+            Self::BlockList(block_list) => Self::BlockList(block_list - allow_list),
         }
     }
 
@@ -175,13 +120,17 @@
     /// We serialize this as a variable length binary array with two items. The first item
     /// is the block list and the second item is the allow list.
     pub fn into_arrow(&self) -> Result<BinaryArray> {
-        let block_list_length = self
-            .block_list
+        // NOTE: This serialization format must be stable as it is used in IPC.
+        let (block_list, allow_list) = match self {
+            Self::AllowList(allow_list) => (None, Some(allow_list)),
+            Self::BlockList(block_list) => (Some(block_list), None),
+        };
+
+        let block_list_length = block_list
             .as_ref()
             .map(|bl| bl.serialized_size())
             .unwrap_or(0);
-        let allow_list_length = self
-            .allow_list
+        let allow_list_length = allow_list
             .as_ref()
             .map(|al| al.serialized_size())
             .unwrap_or(0);
@@ -189,11 +138,11 @@ impl RowIdMask {
         let offsets = OffsetBuffer::from_lengths(lengths);
         let mut value_bytes = vec![0; block_list_length + allow_list_length];
         let mut validity = vec![false, false];
-        if let Some(block_list) = &self.block_list {
+        if let Some(block_list) = &block_list {
             validity[0] = true;
             block_list.serialize_into(&mut value_bytes[0..])?;
         }
-        if let Some(allow_list) = &self.allow_list {
+        if let Some(allow_list) = &allow_list {
             validity[1] = true;
             allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
         }
@@ -217,65 +166,40 @@
             Some(RowIdTreeMap::deserialize_from(array.value(1)))
         }
         .transpose()?;
-        Ok(Self {
-            block_list,
-            allow_list,
-        })
+
+        let res = match (block_list, allow_list) {
+            (Some(bl), None) => Self::BlockList(bl),
+            (None, Some(al)) => Self::AllowList(al),
+            (Some(block), Some(allow)) => Self::AllowList(allow).also_block(block),
+            (None, None) => Self::all_rows(),
+        };
+        Ok(res)
     }
 
     /// Return the maximum number of row ids that could be selected by this mask
     ///
-    /// Will be None if there is no allow list
+    /// Will be None if this is a BlockList (unbounded)
    pub fn max_len(&self) -> Option<u64> {
-        if let Some(allow_list) = &self.allow_list {
-            // If there is a block list we could theoretically intersect the two
-            // but it's not clear if that is worth the effort. Feel free to add later.
-            allow_list.len()
-        } else {
-            None
+        match self {
+            Self::AllowList(selection) => selection.len(),
+            Self::BlockList(_) => None,
         }
     }
 
     /// Iterate over the row ids that are selected by the mask
     ///
-    /// This is only possible if there is an allow list and neither the
-    /// allow list nor the block list contain any "full fragment" blocks.
-    ///
-    /// TODO: We could probably still iterate efficiently even if the block
-    /// list contains "full fragment" blocks but that would require some
-    /// extra logic.
+    /// This is only possible if this is an AllowList and the maps don't contain
+    /// any "full fragment" blocks.
    pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
-        if let Some(mut allow_iter) = self.allow_list.as_ref().and_then(|list| list.row_ids()) {
-            if let Some(block_list) = &self.block_list {
-                if let Some(block_iter) = block_list.row_ids() {
-                    let mut block_iter = block_iter.peekable();
-                    Some(Box::new(iter::from_fn(move || {
-                        for allow_id in allow_iter.by_ref() {
-                            while let Some(block_id) = block_iter.peek() {
-                                if *block_id >= allow_id {
-                                    break;
-                                }
-                                block_iter.next();
-                            }
-                            if let Some(block_id) = block_iter.peek() {
-                                if *block_id == allow_id {
-                                    continue;
-                                }
-                            }
-                            return Some(allow_id);
-                        }
-                        None
-                    })))
+        match self {
+            Self::AllowList(allow_list) => {
+                if let Some(allow_iter) = allow_list.row_ids() {
+                    Some(Box::new(allow_iter))
                 } else {
-                    // There is a block list but we can't iterate over it, give up
                     None
                 }
-            } else {
-                // There is no block list, use the allow list
-                Some(Box::new(allow_iter))
             }
-        } else {
-            None
+            Self::BlockList(_) => None, // Can't iterate over block list
         }
     }
 }
@@ -284,9 +208,9 @@ impl std::ops::Not for RowIdMask {
     type Output = Self;
 
     fn not(self) -> Self::Output {
-        Self {
-            block_list: self.allow_list,
-            allow_list: self.block_list,
+        match self {
+            Self::AllowList(allow_list) => Self::BlockList(allow_list),
+            Self::BlockList(block_list) => Self::AllowList(block_list),
         }
     }
 }
@@ -295,21 +219,11 @@ impl std::ops::BitAnd for RowIdMask {
     type Output = Self;
 
     fn bitand(self, rhs: Self) -> Self::Output {
-        let block_list = match (self.block_list, rhs.block_list) {
-            (None, None) => None,
-            (Some(lhs), None) => Some(lhs),
-            (None, Some(rhs)) => Some(rhs),
-            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
-        };
-        let allow_list = match (self.allow_list, rhs.allow_list) {
-            (None, None) => None,
-            (Some(lhs), None) => Some(lhs),
-            (None, Some(rhs)) => Some(rhs),
-            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
-        };
-        Self {
-            block_list,
-            allow_list,
+        match (self, rhs) {
+            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a & b),
+            (Self::AllowList(allow), Self::BlockList(block))
+            | (Self::BlockList(block), Self::AllowList(allow)) => Self::AllowList(allow - block),
+            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a | b),
         }
     }
 }
@@ -318,44 +232,11 @@ impl std::ops::BitOr for RowIdMask {
     type Output = Self;
 
     fn bitor(self, rhs: Self) -> Self::Output {
-        let this = self.normalize();
-        let rhs = rhs.normalize();
-        let block_list = if let Some(mut self_block_list) = this.block_list {
-            match (&rhs.allow_list, rhs.block_list) {
-                // If RHS is allow all, then our block list disappears
-                (None, None) => None,
-                // If RHS is allow list, remove allowed from our block list
-                (Some(allow_list), None) => {
-                    self_block_list -= allow_list;
-                    Some(self_block_list)
-                }
-                // If RHS is block list, intersect
-                (None, Some(block_list)) => Some(self_block_list & block_list),
-                // We normalized to avoid this path
-                (Some(_), Some(_)) => unreachable!(),
-            }
-        } else if let Some(mut rhs_block_list) = rhs.block_list {
-            if let Some(allow_list) = &this.allow_list {
-                rhs_block_list -= allow_list;
-                Some(rhs_block_list)
-            } else {
-                Some(rhs_block_list)
-            }
-        } else {
-            None
-        };
-
-        let allow_list = match (this.allow_list, rhs.allow_list) {
-            (None, None) => None,
-            // Remember that an allow list of None means "all rows" and
-            // so "all rows" | "some rows" is always "all rows"
-            (Some(_), None) => None,
-            (None, Some(_)) => None,
-            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
-        };
-        Self {
-            block_list,
-            allow_list,
+        match (self, rhs) {
+            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a | b),
+            (Self::AllowList(allow), Self::BlockList(block))
+            | (Self::BlockList(block), Self::AllowList(allow)) => Self::BlockList(block - allow),
+            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a & b),
         }
     }
 }
@@ -679,14 +560,16 @@ impl RowIdTreeMap {
     /// Apply a mask to the row ids
     ///
-    /// If there is an allow list then this will intersect the set with the allow list
-    /// If there is a block list then this will subtract the block list from the set
+    /// For AllowList: intersect the set with the allow list
+    /// For BlockList: subtract the block list from the set
     pub fn mask(&mut self, mask: &RowIdMask) {
-        if let Some(allow_list) = &mask.allow_list {
-            *self &= allow_list;
-        }
-        if let Some(block_list) = &mask.block_list {
-            *self -= block_list;
+        match mask {
+            RowIdMask::AllowList(allow_list) => {
+                *self &= allow_list;
+            }
+            RowIdMask::BlockList(block_list) => {
+                *self -= block_list;
+            }
         }
     }
 
@@ -720,8 +603,23 @@ impl std::ops::BitOr for RowIdTreeMap {
     }
 }
 
+impl std::ops::BitOr<&Self> for RowIdTreeMap {
+    type Output = Self;
+
+    fn bitor(mut self, rhs: &Self) -> Self::Output {
+        self |= rhs;
+        self
+    }
+}
+
 impl std::ops::BitOrAssign for RowIdTreeMap {
     fn bitor_assign(&mut self, rhs: Self) {
+        *self |= &rhs;
+    }
+}
+
+impl std::ops::BitOrAssign<&Self> for RowIdTreeMap {
+    fn bitor_assign(&mut self, rhs: &Self) {
         for (fragment, rhs_set) in &rhs.inner {
             let lhs_set = self.inner.get_mut(fragment);
             if let Some(lhs_set) = lhs_set {
@@ -754,6 +652,21 @@ impl std::ops::BitAnd for RowIdTreeMap {
     }
 }
 
+impl std::ops::BitAnd<&Self> for RowIdTreeMap {
+    type Output = Self;
+
+    fn bitand(mut self, rhs: &Self) -> Self::Output {
+        self &= rhs;
+        self
+    }
+}
+
+impl std::ops::BitAndAssign for RowIdTreeMap {
+    fn bitand_assign(&mut self, rhs: Self) {
+        *self &= &rhs;
+    }
+}
+
 impl std::ops::BitAndAssign<&Self> for RowIdTreeMap {
     fn bitand_assign(&mut self, rhs: &Self) {
         // Remove fragment that aren't on the RHS
@@ -792,6 +705,15 @@ impl std::ops::Sub for RowIdTreeMap {
     }
 }
 
+impl std::ops::Sub<&Self> for RowIdTreeMap {
+    type Output = Self;
+
+    fn sub(mut self, rhs: &Self) -> Self {
+        self -= rhs;
+        self
+    }
+}
+
 impl std::ops::SubAssign<&Self> for RowIdTreeMap {
     fn sub_assign(&mut self, rhs: &Self) {
         for (fragment, rhs_set) in &rhs.inner {
@@ -865,6 +787,14 @@ impl From<Range<u64>> for RowIdTreeMap {
     }
 }
 
+impl From<RangeInclusive<u64>> for RowIdTreeMap {
+    fn from(range: RangeInclusive<u64>) -> Self {
+        let mut map = Self::default();
+        map.insert_range(range);
+        map
+    }
+}
+
 impl From<RoaringTreemap> for RowIdTreeMap {
     fn from(roaring: RoaringTreemap) -> Self {
         let mut inner = BTreeMap::new();
@@ -966,14 +896,10 @@ mod tests {
     fn test_logical_or() {
         let allow1 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[5, 6, 7, 8, 9]));
         let block1 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[5, 6]));
-        let mixed1 = allow1
-            .clone()
-            .also_block(block1.block_list.as_ref().unwrap().clone());
+        let mixed1 = allow1.clone().also_block(RowIdTreeMap::from_iter(&[5, 6]));
         let allow2 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[2, 3, 4, 5, 6, 7, 8]));
         let block2 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[4, 5]));
-        let mixed2 = allow2
-            .clone()
-            .also_block(block2.block_list.as_ref().unwrap().clone());
+        let mixed2 = allow2.clone().also_block(RowIdTreeMap::from_iter(&[4, 5]));
 
         fn check(lhs: &RowIdMask, rhs: &RowIdMask, expected: &[u64]) {
             for mask in [lhs.clone() | rhs.clone(), rhs.clone() | lhs.clone()] {
@@ -1008,6 +934,100 @@
         check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]);
     }
 
+    #[test]
+    fn test_deserialize_legacy_format() {
+        // Test that we can deserialize the old format where both allow_list
+        // and block_list could be present in the serialized form.
+        //
+        // The old format (before this PR) used a struct with both allow_list and block_list
+        // fields. The new format uses an enum. The deserialization code should handle
+        // the case where both lists are present by converting to AllowList(allow - block).
+
+        // Create the RowIdTreeMaps and serialize them directly
+        let allow = RowIdTreeMap::from_iter(&[1, 2, 3, 4, 5, 10, 15]);
+        let block = RowIdTreeMap::from_iter(&[2, 4, 15]);
+
+        // Serialize using the stable RowIdTreeMap serialization format
+        let block_bytes = {
+            let mut buf = Vec::with_capacity(block.serialized_size());
+            block.serialize_into(&mut buf).unwrap();
+            buf
+        };
+        let allow_bytes = {
+            let mut buf = Vec::with_capacity(allow.serialized_size());
+            allow.serialize_into(&mut buf).unwrap();
+            buf
+        };
+
+        // Construct a binary array with both values present (simulating old format)
+        let old_format_array =
+            BinaryArray::from_opt_vec(vec![Some(&block_bytes[..]), Some(&allow_bytes[..])]);
+
+        // Deserialize - should handle this by creating AllowList(allow - block)
+        let deserialized = RowIdMask::from_arrow(&old_format_array).unwrap();
+
+        // The expected result: AllowList([1, 2, 3, 4, 5, 10, 15] - [2, 4, 15]) = [1, 3, 5, 10]
+        let expected_rows = vec![1u64, 3, 5, 10];
+        for row in &expected_rows {
+            assert!(
+                deserialized.selected(*row),
+                "Row {} should be selected",
+                row
+            );
+        }
+
+        // Verify blocked rows are not selected
+        assert!(!deserialized.selected(2), "Row 2 should be blocked");
+        assert!(!deserialized.selected(4), "Row 4 should be blocked");
+        assert!(!deserialized.selected(15), "Row 15 should be blocked");
+
+        // Verify it's an AllowList variant
+        assert!(
+            deserialized.allow_list().is_some(),
+            "Should deserialize to AllowList variant"
+        );
+    }
+
+    #[test]
+    fn test_deserialize_legacy_empty_lists() {
+        // Test edge cases with None values in old format
+
+        // Case 1: Both None (should become all_rows)
+        let array = BinaryArray::from_opt_vec(vec![None, None]);
+        let mask = RowIdMask::from_arrow(&array).unwrap();
+        assert!(mask.selected(0));
+        assert!(mask.selected(100));
+        assert!(mask.selected(u64::MAX));
+
+        // Case 2: Only block list (no allow list)
+        let block = RowIdTreeMap::from_iter(&[5, 10]);
+        let block_bytes = {
+            let mut buf = Vec::with_capacity(block.serialized_size());
+            block.serialize_into(&mut buf).unwrap();
+            buf
+        };
+        let array = BinaryArray::from_opt_vec(vec![Some(&block_bytes[..]), None]);
+        let mask = RowIdMask::from_arrow(&array).unwrap();
+        assert!(mask.selected(0));
+        assert!(!mask.selected(5));
+        assert!(!mask.selected(10));
+        assert!(mask.selected(15));
+
+        // Case 3: Only allow list (no block list)
+        let allow = RowIdTreeMap::from_iter(&[5, 10]);
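+        // Layout reminder (not new behavior, see `into_arrow` above): the
+        // serialized mask is a two-slot binary array where slot 0 is the block
+        // list and slot 1 is the allow list, so an allow-list-only mask leaves
+        // slot 0 as null.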
+        let allow_bytes = {
+            let mut buf = Vec::with_capacity(allow.serialized_size());
+            allow.serialize_into(&mut buf).unwrap();
+            buf
+        };
+        let array = BinaryArray::from_opt_vec(vec![None, Some(&allow_bytes[..])]);
+        let mask = RowIdMask::from_arrow(&array).unwrap();
+        assert!(!mask.selected(0));
+        assert!(mask.selected(5));
+        assert!(mask.selected(10));
+        assert!(!mask.selected(15));
+    }
+
     #[test]
     fn test_map_insert_range() {
         let ranges = &[
@@ -1224,52 +1244,4 @@
         }
     }
-
-    #[test]
-    fn test_iter_ids() {
-        let mut mask = RowIdMask::default();
-        assert!(mask.iter_ids().is_none());
-
-        // Test with just an allow list
-        let mut allow_list = RowIdTreeMap::default();
-        allow_list.extend([1, 5, 10].iter().copied());
-        mask.allow_list = Some(allow_list);
-
-        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
-        assert_eq!(
-            ids,
-            vec![
-                RowAddress::new_from_parts(0, 1),
-                RowAddress::new_from_parts(0, 5),
-                RowAddress::new_from_parts(0, 10)
-            ]
-        );
-
-        // Test with both allow list and block list
-        let mut block_list = RowIdTreeMap::default();
-        block_list.extend([5].iter().copied());
-        mask.block_list = Some(block_list);
-
-        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
-        assert_eq!(
-            ids,
-            vec![
-                RowAddress::new_from_parts(0, 1),
-                RowAddress::new_from_parts(0, 10)
-            ]
-        );
-
-        // Test with full fragment in block list
-        let mut block_list = RowIdTreeMap::default();
-        block_list.insert_fragment(0);
-        mask.block_list = Some(block_list);
-        assert!(mask.iter_ids().is_none());
-
-        // Test with full fragment in allow list
-        mask.block_list = None;
-        let mut allow_list = RowIdTreeMap::default();
-        allow_list.insert_fragment(0);
-        mask.allow_list = Some(allow_list);
-        assert!(mask.iter_ids().is_none());
-    }
 }
diff --git a/rust/lance-core/src/utils/mask/nullable.rs b/rust/lance-core/src/utils/mask/nullable.rs
new file mode 100644
index 0000000000..f17808d6e7
--- /dev/null
+++ b/rust/lance-core/src/utils/mask/nullable.rs
@@ -0,0 +1,457 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use deepsize::DeepSizeOf;
+
+use super::{RowIdMask, RowIdTreeMap};
+
+/// A set of row ids, with optional set of nulls.
+///
+/// This is often a result of a filter, where `selected` represents the rows that
+/// passed the filter, and `nulls` represents the rows where the filter evaluated
+/// to null. For example, in SQL `NULL > 5` evaluates to null. This is distinct
+/// from being deselected to support proper three-valued logic for NOT.
+/// (`NOT FALSE` is TRUE, `NOT TRUE` is FALSE, but `NOT NULL` is NULL.)
+#[derive(Clone, Debug, Default, DeepSizeOf)]
+pub struct NullableRowIdSet {
+    selected: RowIdTreeMap,
+    nulls: RowIdTreeMap,
+}
+
+impl NullableRowIdSet {
+    /// Create a new NullableRowIdSet from selected rows and null rows.
+    ///
+    /// `nulls` may have overlap with `selected`. Rows in `nulls` are considered NULL,
+    /// even if they are also in `selected`.
+    pub fn new(selected: RowIdTreeMap, nulls: RowIdTreeMap) -> Self {
+        Self { selected, nulls }
+    }
+
+    pub fn with_nulls(mut self, nulls: RowIdTreeMap) -> Self {
+        self.nulls = nulls;
+        self
+    }
+
+    /// Create an empty selection. Alias for [Default::default]
+    pub fn empty() -> Self {
+        Default::default()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.selected.is_empty()
+    }
+
+    /// Check if a row_id is selected (TRUE)
+    pub fn selected(&self, row_id: u64) -> bool {
+        self.selected.contains(row_id) && !self.nulls.contains(row_id)
+    }
+
+    /// Get the selected rows (TRUE or NULL)
+    pub fn selected_rows(&self) -> &RowIdTreeMap {
+        &self.selected
+    }
+
+    /// Get the null rows
+    pub fn null_rows(&self) -> &RowIdTreeMap {
+        &self.nulls
+    }
+
+    /// Get the TRUE rows (selected but not null)
+    pub fn true_rows(&self) -> RowIdTreeMap {
+        self.selected.clone() - self.nulls.clone()
+    }
+
+    pub fn union_all(selections: &[Self]) -> Self {
+        let selected = RowIdTreeMap::union_all(
+            &selections
+                .iter()
+                .map(|s| &s.selected)
+                .collect::<Vec<_>>(),
+        );
+        let nulls = RowIdTreeMap::union_all(
+            &selections
+                .iter()
+                .map(|s| &s.nulls)
+                .collect::<Vec<_>>(),
+        );
+        Self { selected, nulls }
+    }
+}
+
+impl PartialEq for NullableRowIdSet {
+    fn eq(&self, other: &Self) -> bool {
+        self.true_rows() == other.true_rows() && self.nulls == other.nulls
+    }
+}
+
+impl std::ops::BitAndAssign<&Self> for NullableRowIdSet {
+    fn bitand_assign(&mut self, rhs: &Self) {
+        self.nulls = if self.nulls.is_empty() && rhs.nulls.is_empty() {
+            RowIdTreeMap::new() // Fast path
+        } else {
+            (self.nulls.clone() & &rhs.nulls) // null and null -> null
+                | (self.nulls.clone() & &rhs.selected) // null and true -> null
+                | (rhs.nulls.clone() & &self.selected) // true and null -> null
+        };
+
+        self.selected &= &rhs.selected;
+    }
+}
+
+impl std::ops::BitOrAssign<&Self> for NullableRowIdSet {
+    fn bitor_assign(&mut self, rhs: &Self) {
+        self.nulls = if self.nulls.is_empty() && rhs.nulls.is_empty() {
+            RowIdTreeMap::new() // Fast path
+        } else {
+            // null or null -> null (excluding rows that are true in either)
+            let true_rows =
+                (self.selected.clone() - &self.nulls) | (rhs.selected.clone() - &rhs.nulls);
+            (self.nulls.clone() | &rhs.nulls) - true_rows
+        };
+
+        self.selected |= &rhs.selected;
+    }
+}
+
+/// A version of [`RowIdMask`] that supports nulls.
+///
+/// This mask handles three-valued logic for SQL expressions, where a filter can
+/// evaluate to TRUE, FALSE, or NULL. The `selected` set includes rows that are
+/// TRUE or NULL. The `nulls` set includes rows that are NULL.
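+///
+/// A sketch of the Kleene truth table implemented by the `Not`, `BitAnd`, and
+/// `BitOr` impls below:
+///
+/// | a     | b     | a AND b | a OR b |
+/// |-------|-------|---------|--------|
+/// | TRUE  | NULL  | NULL    | TRUE   |
+/// | FALSE | NULL  | FALSE   | NULL   |
+/// | NULL  | NULL  | NULL    | NULL   |
+///
+/// `NOT` maps TRUE to FALSE, FALSE to TRUE, and leaves NULL as NULL, which is
+/// why it can be implemented by swapping the AllowList/BlockList variants
+/// while keeping the same `NullableRowIdSet`.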
+#[derive(Clone, Debug)]
+pub enum NullableRowIdMask {
+    AllowList(NullableRowIdSet),
+    BlockList(NullableRowIdSet),
+}
+
+impl NullableRowIdMask {
+    pub fn selected(&self, row_id: u64) -> bool {
+        match self {
+            Self::AllowList(NullableRowIdSet { selected, nulls }) => {
+                selected.contains(row_id) && !nulls.contains(row_id)
+            }
+            Self::BlockList(NullableRowIdSet { selected, nulls }) => {
+                !selected.contains(row_id) && !nulls.contains(row_id)
+            }
+        }
+    }
+
+    pub fn drop_nulls(self) -> RowIdMask {
+        match self {
+            Self::AllowList(NullableRowIdSet { selected, nulls }) => {
+                RowIdMask::AllowList(selected - nulls)
+            }
+            Self::BlockList(NullableRowIdSet { selected, nulls }) => {
+                RowIdMask::BlockList(selected | nulls)
+            }
+        }
+    }
+}
+
+impl std::ops::Not for NullableRowIdMask {
+    type Output = Self;
+
+    fn not(self) -> Self::Output {
+        match self {
+            Self::AllowList(set) => Self::BlockList(set),
+            Self::BlockList(set) => Self::AllowList(set),
+        }
+    }
+}
+
+impl std::ops::BitAnd for NullableRowIdMask {
+    type Output = Self;
+
+    fn bitand(self, rhs: Self) -> Self::Output {
+        // Null handling:
+        // * null and true -> null
+        // * null and null -> null
+        // * null and false -> false
+        match (self, rhs) {
+            (Self::AllowList(a), Self::AllowList(b)) => {
+                let nulls = if a.nulls.is_empty() && b.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    (a.nulls.clone() & &b.nulls) // null and null -> null
+                        | (a.nulls & &b.selected) // null and true -> null
+                        | (b.nulls & &a.selected) // true and null -> null
+                };
+                let selected = a.selected & b.selected;
+                Self::AllowList(NullableRowIdSet { selected, nulls })
+            }
+            (Self::AllowList(allow), Self::BlockList(block))
+            | (Self::BlockList(block), Self::AllowList(allow)) => {
+                let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    (allow.nulls.clone() & &block.nulls) // null and null -> null
+                        | (allow.nulls - &block.selected) // null and true -> null
+                        | (block.nulls & &allow.selected) // true and null -> null
+                };
+                let selected = allow.selected - block.selected;
+                Self::AllowList(NullableRowIdSet { selected, nulls })
+            }
+            (Self::BlockList(a), Self::BlockList(b)) => {
+                let nulls = if a.nulls.is_empty() && b.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    (a.nulls.clone() & &b.nulls) // null and null -> null
+                        | (a.nulls - &b.selected) // null and true -> null
+                        | (b.nulls - &a.selected) // true and null -> null
+                };
+                let selected = a.selected | b.selected;
+                Self::BlockList(NullableRowIdSet { selected, nulls })
+            }
+        }
+    }
+}
+
+impl std::ops::BitOr for NullableRowIdMask {
+    type Output = Self;
+
+    fn bitor(self, rhs: Self) -> Self::Output {
+        // Null handling:
+        // * null or true -> true
+        // * null or null -> null
+        // * null or false -> null
+        match (self, rhs) {
+            (Self::AllowList(a), Self::AllowList(b)) => {
+                let nulls = if a.nulls.is_empty() && b.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    // null or null -> null (excluding rows that are true in either)
+                    let true_rows =
+                        (a.selected.clone() - &a.nulls) | (b.selected.clone() - &b.nulls);
+                    (a.nulls | b.nulls) - true_rows
+                };
+                let selected = (a.selected | b.selected) | &nulls;
+                Self::AllowList(NullableRowIdSet { selected, nulls })
+            }
+            (Self::AllowList(allow), Self::BlockList(block))
+            | (Self::BlockList(block), Self::AllowList(allow)) => {
+                let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    // null or null -> null (excluding rows that are true in either)
+                    let allow_true = allow.selected.clone() - &allow.nulls;
+                    ((allow.nulls | block.nulls) & block.selected.clone()) - allow_true
+                };
+                let selected = (block.selected - allow.selected) | &nulls;
+                Self::BlockList(NullableRowIdSet { selected, nulls })
+            }
+            (Self::BlockList(a), Self::BlockList(b)) => {
+                let nulls = if a.nulls.is_empty() && b.nulls.is_empty() {
+                    RowIdTreeMap::new() // Fast path
+                } else {
+                    // null or null -> null (excluding rows that are true in either)
+                    let false_rows =
+                        (a.selected.clone() - &a.nulls) & (b.selected.clone() - &b.nulls);
+                    (a.nulls | &b.nulls) - false_rows
+                };
+                let selected = (a.selected & b.selected) | &nulls;
+                Self::BlockList(NullableRowIdSet { selected, nulls })
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_not_with_nulls() {
+        // Test case from issue #4756: x != 5 on data [0, 5, null]
+        // x = 5 should return: AllowList with selected=[1,2], nulls=[2]
+        // NOT(x = 5) should return: BlockList with selected=[1,2], nulls=[2]
+        // selected() should return TRUE for row 0, FALSE for rows 1 and 2
+        let mask = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[1, 2]), // rows where x==5 or x==null
+            RowIdTreeMap::from_iter(&[2]),    // row where x is null
+        ));
+
+        let not_mask = !mask;
+
+        // Row 0: should be selected (x=0, which is != 5)
+        assert!(
+            not_mask.selected(0),
+            "Row 0 (x=0) should be selected for x != 5"
+        );
+
+        // Row 1: should NOT be selected (x=5, which is == 5)
+        assert!(
+            !not_mask.selected(1),
+            "Row 1 (x=5) should NOT be selected for x != 5"
+        );
+
+        // Row 2: should NOT be selected (x=null, comparison result is null)
+        assert!(
+            !not_mask.selected(2),
+            "Row 2 (x=null) should NOT be selected for x != 5"
+        );
+    }
+
+    #[test]
+    fn test_and_with_nulls() {
+        // Test Kleene AND logic: true AND null = null, false AND null = false
+
+        // Case 1: TRUE mask AND mask with nulls
+        let true_mask = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2, 3, 4]), // All TRUE
+            RowIdTreeMap::new(),                       // No nulls
+        ));
+        let null_mask = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2, 3, 4]), // TRUE or NULL
+            RowIdTreeMap::from_iter(&[1, 3]),          // NULL rows
+        ));
+        let result = true_mask & null_mask.clone();
+
+        // TRUE AND TRUE = TRUE
+        assert!(result.selected(0));
+        assert!(result.selected(2));
+        assert!(result.selected(4));
+        // TRUE AND NULL = NULL (filtered out)
+        assert!(!result.selected(1));
+        assert!(!result.selected(3));
+
+        // Case 2: FALSE mask AND mask with nulls
+        let false_mask = NullableRowIdMask::BlockList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2, 3, 4]), // All FALSE
+            RowIdTreeMap::new(),                       // No nulls
+        ));
+        let result = false_mask & null_mask;
+
+        // FALSE AND anything = FALSE
+        assert!(!result.selected(0));
+        assert!(!result.selected(1));
+        assert!(!result.selected(2));
+        assert!(!result.selected(3));
+        assert!(!result.selected(4));
+
+        // Case 3: Both masks have nulls - union of null sets
+        let mask1 = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2]), // TRUE or NULL
+            RowIdTreeMap::from_iter(&[1]),       // NULL rows
+        ));
+        let mask2 = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 2, 3]), // TRUE or NULL
+            RowIdTreeMap::from_iter(&[2]),       // NULL rows
+        ));
+        let result = mask1 & mask2;
+
+        // Only row 0 is TRUE in both
+        assert!(result.selected(0));
+        // Rows 1, 2 are null in at least one
+        assert!(!result.selected(1));
+        assert!(!result.selected(2));
+        // Row 3 is not in first mask's selected
+        assert!(!result.selected(3));
+    }
+
+    #[test]
+    fn test_or_with_nulls() {
+        // Test Kleene OR logic: true OR null = true, false OR null = null
+
+        // Case 1: FALSE mask OR mask with nulls
+        let false_mask = NullableRowIdMask::BlockList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2]), // All FALSE
+            RowIdTreeMap::new(),                 // No nulls
+        ));
+        let null_mask = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2]), // TRUE or NULL
+            RowIdTreeMap::from_iter(&[1, 2]),    // NULL rows
+        ));
+        let result = false_mask | null_mask.clone();
+
+        // FALSE OR TRUE = TRUE
+        assert!(result.selected(0));
+        // FALSE OR NULL = NULL (filtered out)
+        assert!(!result.selected(1));
+        assert!(!result.selected(2));
+
+        // Case 2: TRUE mask OR mask with nulls
+        let true_mask = NullableRowIdMask::AllowList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2]), // All TRUE
+            RowIdTreeMap::new(),                 // No nulls
+        ));
+        let result = true_mask | null_mask;
+
+        // TRUE OR anything = TRUE
+        assert!(result.selected(0));
+        assert!(result.selected(1));
+        assert!(result.selected(2));
+
+        // Case 3: Both have nulls
+        let mask1 = NullableRowIdMask::BlockList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2, 3]), // FALSE or NULL
+            RowIdTreeMap::from_iter(&[1, 2]),       // NULL rows
+        ));
+        let mask2 = NullableRowIdMask::BlockList(NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[0, 1, 2, 3]), // FALSE or NULL
+            RowIdTreeMap::from_iter(&[2, 3]),       // NULL rows
+        ));
+        let result = mask1 | mask2;
+
+        // Row 0 is FALSE in both
+        assert!(!result.selected(0));
+        // Row 1 is NULL in first, FALSE in second -> NULL
+        assert!(!result.selected(1));
+        // Row 2 is NULL in both -> NULL
+        assert!(!result.selected(2));
+        // Row 3 is FALSE in first, NULL in second -> NULL
+        assert!(!result.selected(3));
+    }
+
+    #[test]
+    fn test_row_selection_bit_or() {
+        // [T, N, T, N, F, F, F]
+        let left = NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[1, 2, 3, 4]),
+            RowIdTreeMap::from_iter(&[2, 4]),
+        );
+        // [F, F, T, N, T, N, N]
+        let right = NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[3, 4, 5, 6]),
+            RowIdTreeMap::from_iter(&[4, 6, 7]),
+        );
+        // [T, N, T, N, T, N, N]
+        let expected_true = RowIdTreeMap::from_iter(&[1, 3, 5]);
+        let expected_nulls = RowIdTreeMap::from_iter(&[2, 4, 6, 7]);
+
+        let mut result = left.clone();
+        result |= &right;
+        assert_eq!(&result.true_rows(), &expected_true);
+        assert_eq!(result.null_rows(), &expected_nulls);
+        // Commutative property holds
+        let mut result = right.clone();
+        result |= &left;
+        assert_eq!(&result.true_rows(), &expected_true);
+        assert_eq!(result.null_rows(), &expected_nulls);
+    }
+
+    #[test]
+    fn test_row_selection_bit_and() {
+        // [T, N, T, N, F, F, F]
+        let left = NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[1, 2, 3, 4]),
+            RowIdTreeMap::from_iter(&[2, 4]),
+        );
+        // [F, F, T, N, T, N, N]
+        let right = NullableRowIdSet::new(
+            RowIdTreeMap::from_iter(&[3, 4, 5, 6]),
+            RowIdTreeMap::from_iter(&[4, 6, 7]),
+        );
+        // [F, F, T, N, F, F, F]
+        let expected_true = RowIdTreeMap::from_iter(&[3]);
+        let expected_nulls = RowIdTreeMap::from_iter(&[4]);
+        let mut result = left.clone();
+        result &= &right;
+        assert_eq!(&result.true_rows(), &expected_true);
+        assert_eq!(result.null_rows(), &expected_nulls);
+        // Commutative property holds
+        let mut result = right.clone();
+        result &= &left;
+        assert_eq!(&result.true_rows(), &expected_true);
+        assert_eq!(result.null_rows(), &expected_nulls);
+    }
+}
diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs
index 69b5ee35cf..ee5c7d1e31 100644
--- a/rust/lance-index/src/scalar.rs
+++ b/rust/lance-index/src/scalar.rs
@@ -19,7 +19,7 @@ use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::Expr;
 use deepsize::DeepSizeOf;
 use inverted::query::{fill_fts_query_column, FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery};
-use lance_core::utils::mask::RowIdTreeMap;
+use lance_core::utils::mask::{NullableRowIdSet, RowIdTreeMap};
 use lance_core::{Error, Result};
 use serde::Serialize;
 use snafu::location;
@@ -684,20 +684,40 @@ impl AnyQuery for TokenQuery {
 #[derive(Debug, PartialEq)]
 pub enum SearchResult {
     /// The exact row ids that satisfy the query
-    Exact(RowIdTreeMap),
+    Exact(NullableRowIdSet),
     /// Any row id satisfying the query will be in this set but not every
     /// row id in this set will satisfy the query, a further recheck step
     /// is needed
-    AtMost(RowIdTreeMap),
+    AtMost(NullableRowIdSet),
     /// All of the given row ids satisfy the query but there may be more
     ///
     /// No scalar index actually returns this today but it can arise from
     /// boolean operations (e.g. NOT(AtMost(x)) == AtLeast(NOT(x)))
-    AtLeast(RowIdTreeMap),
+    AtLeast(NullableRowIdSet),
 }
 
 impl SearchResult {
-    pub fn row_ids(&self) -> &RowIdTreeMap {
+    pub fn exact(row_ids: impl Into<RowIdTreeMap>) -> Self {
+        Self::Exact(NullableRowIdSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn at_most(row_ids: impl Into<RowIdTreeMap>) -> Self {
+        Self::AtMost(NullableRowIdSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn at_least(row_ids: impl Into<RowIdTreeMap>) -> Self {
+        Self::AtLeast(NullableRowIdSet::new(row_ids.into(), Default::default()))
+    }
+
+    pub fn with_nulls(self, nulls: impl Into<RowIdTreeMap>) -> Self {
+        match self {
+            Self::Exact(row_ids) => Self::Exact(row_ids.with_nulls(nulls.into())),
+            Self::AtMost(row_ids) => Self::AtMost(row_ids.with_nulls(nulls.into())),
+            Self::AtLeast(row_ids) => Self::AtLeast(row_ids.with_nulls(nulls.into())),
+        }
+    }
+
+    pub fn row_ids(&self) -> &NullableRowIdSet {
         match self {
             Self::Exact(row_ids) => row_ids,
             Self::AtMost(row_ids) => row_ids,
diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 9f3779668f..94d1dadf6e 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -21,7 +21,7 @@ use futures::TryStreamExt;
 use lance_core::{
     cache::{CacheKey, LanceCache, WeakLanceCache},
     error::LanceOptionExt,
-    utils::mask::RowIdTreeMap,
+    utils::mask::{NullableRowIdSet, RowIdTreeMap},
     Error, Result, ROW_ID,
 };
 use roaring::RoaringBitmap;
@@ -403,15 +403,21 @@ impl ScalarIndex for BitmapIndex {
     ) -> Result<SearchResult> {
         let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
 
-        let row_ids = match query {
+        let (row_ids, null_row_ids) = match query {
             SargableQuery::Equals(val) => {
                 metrics.record_comparisons(1);
                 if val.is_null() {
-                    (*self.null_map).clone()
+                    // Querying FOR nulls - they are the TRUE result, not NULL result
+                    ((*self.null_map).clone(), None)
                 } else {
                     let key = OrderableScalarValue(val.clone());
                     let bitmap = self.load_bitmap(&key, Some(metrics)).await?;
-                    (*bitmap).clone()
+                    let null_rows = if !self.null_map.is_empty() {
+                        Some((*self.null_map).clone())
+                    } else {
+                        None
+                    };
+                    ((*bitmap).clone(), null_rows)
                 }
             }
             SargableQuery::Range(start, end) => {
@@ -435,7 +441,7 @@
 
                 metrics.record_comparisons(keys.len());
 
-                if keys.is_empty() {
+                let result = if keys.is_empty() {
                     RowIdTreeMap::default()
                 } else {
                     let mut bitmaps = Vec::new();
@@ -446,7 +452,14 @@
 
                     let bitmap_refs: Vec<_> = bitmaps.iter().map(|b| b.as_ref()).collect();
                     RowIdTreeMap::union_all(&bitmap_refs)
-                }
+                };
+
+                let null_rows = if !self.null_map.is_empty() {
+                    Some((*self.null_map).clone())
+                } else {
+                    None
+                };
+                (result, null_rows)
             }
             SargableQuery::IsIn(values) => {
                 metrics.record_comparisons(values.len());
@@ -471,17 +484,27 @@
                     bitmaps.push(self.null_map.clone());
                 }
 
-                if bitmaps.is_empty() {
+                let result = if bitmaps.is_empty() {
                     RowIdTreeMap::default()
                 } else {
                     // Convert Arc<RowIdTreeMap> to &RowIdTreeMap for union_all
                     let bitmap_refs: Vec<_> = bitmaps.iter().map(|b| b.as_ref()).collect();
                     RowIdTreeMap::union_all(&bitmap_refs)
-                }
+                };
+
+                // If the query explicitly includes null, then nulls are TRUE (not NULL)
+                // Otherwise, nulls remain NULL (unknown)
+                let null_rows = if !has_null && !self.null_map.is_empty() {
+                    Some((*self.null_map).clone())
+                } else {
+                    None
+                };
+                (result, null_rows)
             }
             SargableQuery::IsNull() => {
                 metrics.record_comparisons(1);
-                (*self.null_map).clone()
+                // Querying FOR nulls - they are the TRUE result, not NULL result
+                ((*self.null_map).clone(), None)
             }
             SargableQuery::FullTextSearch(_) => {
                 return Err(Error::NotSupported {
@@ -491,7 +514,8 @@
             }
         };
 
-        Ok(SearchResult::Exact(row_ids))
+        let selection = NullableRowIdSet::new(row_ids, null_row_ids.unwrap_or_default());
+        Ok(SearchResult::Exact(selection))
     }
 
     fn can_remap(&self) -> bool {
@@ -766,7 +790,7 @@ pub mod tests {
     use super::*;
     use crate::metrics::NoOpMetricsCollector;
     use crate::scalar::lance_format::LanceIndexStore;
-    use arrow_array::{RecordBatch, StringArray, UInt64Array};
+    use arrow_array::{record_batch, RecordBatch, StringArray, UInt64Array};
     use arrow_schema::{Field, Schema};
     use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
     use futures::stream;
@@ -831,7 +855,12 @@ pub mod tests {
         // Verify results
         let expected_red_rows = vec![0u64, 3, 6, 10, 11];
         if let SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(|id| id.into()).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(|id| id.into())
+                .collect();
             actual.sort();
             assert_eq!(actual, expected_red_rows);
         } else {
@@ -841,7 +870,12 @@
         // Test 2: Search for "red" again - should hit cache
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         if let SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(|id| id.into()).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(|id| id.into())
+                .collect();
             actual.sort();
             assert_eq!(actual, expected_red_rows);
         }
@@ -855,7 +889,12 @@
 
         let expected_range_rows = vec![1u64, 2, 5, 7, 8, 12, 13];
         if let SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(|id| id.into()).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(|id| id.into())
+                .collect();
             actual.sort();
             assert_eq!(actual, expected_range_rows);
         }
@@ -869,7 +908,12 @@
 
         let expected_in_rows = vec![0u64, 3, 4, 6, 9, 10, 11, 14];
         if let SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(|id| id.into()).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(|id| id.into())
+                .collect();
             actual.sort();
             assert_eq!(actual, expected_in_rows);
         }
@@ -1240,7 +1284,12 @@
             .await
             .unwrap();
         if let crate::scalar::SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(u64::from).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(u64::from)
+                .collect();
             actual.sort();
             let expected: Vec<u64> = vec![
                 RowAddress::new_from_parts(3, 2).into(),
@@ -1256,7 +1305,12 @@
             .await
             .unwrap();
         if let crate::scalar::SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(u64::from).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(u64::from)
+                .collect();
             actual.sort();
             let expected: Vec<u64> = vec![
                 RowAddress::new_from_parts(3, 4).into(),
@@ -1272,7 +1326,12 @@
             .await
             .unwrap();
         if let crate::scalar::SearchResult::Exact(row_ids) = result {
-            let mut actual: Vec<u64> = row_ids.row_ids().unwrap().map(u64::from).collect();
+            let mut actual: Vec<u64> = row_ids
+                .true_rows()
+                .row_ids()
+                .unwrap()
+                .map(u64::from)
+                .collect();
             actual.sort();
             assert_eq!(
                 actual, expected_null_addrs,
             );
         }
     }
+
+    #[tokio::test]
+    async fn test_bitmap_null_handling_in_queries() {
+        // Test that bitmap index correctly returns null_list for queries
+        let tmpdir = TempObjDir::default();
+        let store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            tmpdir.clone(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Create test data: [0, 5, null]
+        let batch = record_batch!(
+            ("value", Int64, [Some(0), Some(5), None]),
+            ("_rowid", UInt64, [0, 1, 2])
+        )
+        .unwrap();
+        let schema = batch.schema();
+        let stream = stream::once(async move { Ok(batch) });
+        let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+
+        // Train and write the bitmap index
+        BitmapIndexPlugin::train_bitmap_index(stream, store.as_ref())
+            .await
+            .unwrap();
+
+        let cache = LanceCache::with_capacity(1024 * 1024);
+        let index = BitmapIndex::load(store.clone(), None, &cache)
+            .await
+            .unwrap();
+
+        // Test 1: Search for value 5 - should return allow=[1], null=[2]
+        let query = SargableQuery::Equals(ScalarValue::Int64(Some(5)));
+        let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
+
+        match result {
+            SearchResult::Exact(row_ids) => {
+                let actual_rows: Vec<u64> = row_ids
+                    .true_rows()
+                    .row_ids()
+                    .unwrap()
+                    .map(u64::from)
+                    .collect();
+                assert_eq!(actual_rows, vec![1], "Should find row 1 where value == 5");
+
+                let null_row_ids = row_ids.null_rows();
+                // Check that null_row_ids contains row 2
+                assert!(!null_row_ids.is_empty(), "null_row_ids should be Some");
+                let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect();
+                assert_eq!(null_rows, vec![2], "Should report row 2 as null");
+            }
+            _ => panic!("Expected Exact search result"),
+        }
+
+        // Test 2: Search for null values - should return allow=[2], null=None
+        let query = SargableQuery::IsNull();
+        let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
+
+        match result {
+            SearchResult::Exact(row_ids) => {
+                let actual_rows: Vec<u64> = row_ids
+                    .true_rows()
+                    .row_ids()
+                    .unwrap()
+                    .map(u64::from)
+                    .collect();
+                assert_eq!(
+                    actual_rows,
+                    vec![2],
+                    "IsNull should find row 2 where value is null"
+                );
+
+                let null_row_ids = row_ids.null_rows();
+                // When querying FOR nulls, null_row_ids should be None (nulls are the TRUE result)
+                assert!(
+                    null_row_ids.is_empty(),
+                    "null_row_ids should be None for IsNull query"
+                );
+            }
+            _ => panic!("Expected Exact search result"),
+        }
+
+        // Test 3: Range query - should return matching rows and null_list
+        let query = SargableQuery::Range(
+            std::ops::Bound::Included(ScalarValue::Int64(Some(0))),
+            std::ops::Bound::Included(ScalarValue::Int64(Some(3))),
+        );
+        let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
+
+        match result {
+            SearchResult::Exact(row_ids) => {
+                let actual_rows: Vec<u64> = row_ids
+                    .true_rows()
+                    .row_ids()
+                    .unwrap()
+                    .map(u64::from)
+                    .collect();
+                assert_eq!(actual_rows, vec![0], "Should find row 0 where value == 0");
+
+                // Should report row 2 as null
+                let null_row_ids = row_ids.null_rows();
+                assert!(!null_row_ids.is_empty(), "null_row_ids should be Some");
+                let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect();
+                assert_eq!(null_rows, vec![2], "Should report row 2 as null");
+            }
+            _ => panic!("Expected Exact search result"),
+        }
+    }
 }
diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs
index 7fef76136e..d7d151b960 100644
--- a/rust/lance-index/src/scalar/bloomfilter.rs
+++ b/rust/lance-index/src/scalar/bloomfilter.rs
@@ -479,7 +479,7 @@ impl ScalarIndex for BloomFilterIndex {
             }
         }
 
-        Ok(SearchResult::AtMost(row_id_tree_map))
+        Ok(SearchResult::at_most(row_id_tree_map))
     }
 
     fn can_remap(&self) -> bool {
@@ -1292,7 +1292,7 @@ mod tests {
     use std::sync::Arc;
 
     use crate::scalar::bloomfilter::BloomFilterIndexPlugin;
-    use arrow_array::{RecordBatch, UInt64Array};
+    use arrow_array::{record_batch, RecordBatch, UInt64Array};
     use arrow_schema::{DataType, Field, Schema};
     use datafusion::execution::SendableRecordBatchStream;
     use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -1376,7 +1376,7 @@ mod tests {
         // Equals query: null (should match nothing, as there are no nulls in empty index)
         let query = BloomFilterQuery::Equals(ScalarValue::Int32(None));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
     }
 
     #[tokio::test]
@@ -1433,14 +1433,14 @@ mod tests {
         // Should match the block since value 50 is in the range [0, 100)
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value that shouldn't exist
         let query = BloomFilterQuery::Equals(ScalarValue::Int32(Some(500))); // Value not in [0, 100)
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty result since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
 
         // Test calculate_included_frags
         assert_eq!(
@@ -1527,7 +1527,7 @@ mod tests {
         // Value 150 is only in fragment 1 (values 100-199), not in fragment 0 (values 0-99)
         let mut expected = RowIdTreeMap::new();
         expected.insert_range((1u64 << 32) + 50..((1u64 << 32) + 100)); // Only the block containing 150
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test calculate_included_frags
         assert_eq!(
@@ -1593,7 +1593,7 @@ mod tests {
         // Should match all blocks since they all contain NaN values
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..500); // All rows since NaN is in every block
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a specific finite value that exists in the data
         let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(5.0)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should match only the first block since 5.0 only exists in rows 0-99
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value that doesn't exist but is within expected range
         let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(250.0)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should match the third block since 250.0 would be in that range if it existed
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(200..300);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value way outside the range
         let query = BloomFilterQuery::Equals(ScalarValue::Float32(Some(10000.0)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
 
         // Test IsIn query with NaN and finite values
         let query = BloomFilterQuery::IsIn(vec![
@@ -1631,7 +1631,7 @@ mod tests {
         // Should match all blocks since they all contain NaN values
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..500);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
     }
 
     #[tokio::test]
@@ -1692,14 +1692,14 @@ mod tests {
         // Should match zone 2
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(2000..3000);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value way outside the range
         let query = BloomFilterQuery::Equals(ScalarValue::Int64(Some(50000)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
 
         // Test IsIn query with values from different zones
         let query = BloomFilterQuery::IsIn(vec![
@@ -1715,7 +1715,7 @@ mod tests {
         expected.insert_range(0..1000); // Zone 0
         expected.insert_range(2000..3000); // Zone 2
         expected.insert_range(7000..8000); // Zone 7
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test calculate_included_frags
         assert_eq!(
@@ -1771,7 +1771,7 @@ mod tests {
         // Should match the first zone
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value in the second zone
         let query = BloomFilterQuery::Equals(ScalarValue::Utf8(Some("value_150".to_string())));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should match the second zone
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(100..200);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value that doesn't exist
         let query =
@@ -1788,7 +1788,7 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
 
         // Test IsIn query with string values
         let query = BloomFilterQuery::IsIn(vec![
@@ -1801,7 +1801,7 @@ mod tests {
         // Should match both zones
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..200);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
     }
 
     #[tokio::test]
@@ -1853,7 +1853,7 @@ mod tests {
         // Should match the first zone
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..50);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value in the second zone
         let query = BloomFilterQuery::Equals(ScalarValue::Binary(Some(vec![75, 76, 77])));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should match the second zone
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(50..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value that doesn't exist
         let query = BloomFilterQuery::Equals(ScalarValue::Binary(Some(vec![255, 254, 253])));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
     }
 
     #[tokio::test]
@@ -1922,7 +1922,7 @@ mod tests {
         // Should match the first zone
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..50);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for a value that doesn't exist
         let query = BloomFilterQuery::Equals(ScalarValue::LargeUtf8(Some(
             "value_that_does_not_exist".to_string(),
         )));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
 
         // Should return empty since bloom filter correctly filters out this value
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
     }
 
     #[tokio::test]
@@ -1978,19 +1978,19 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..50);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for Date32 value in second zone
         let query = BloomFilterQuery::Equals(ScalarValue::Date32(Some(75)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(50..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for Date32 value that doesn't exist
         let query = BloomFilterQuery::Equals(ScalarValue::Date32(Some(500)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
     }
 
     #[tokio::test]
@@ -2042,7 +2042,7 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..50);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for Timestamp value in second zone
         let second_timestamp = timestamp_values[75];
         let query = BloomFilterQuery::Equals(ScalarValue::TimestampNanosecond(
             Some(second_timestamp),
             None,
         ));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(50..100);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for Timestamp value that doesn't exist
         let query =
             BloomFilterQuery::Equals(ScalarValue::TimestampNanosecond(Some(999_999_999i64), None));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
 
         // Test IsIn query with multiple timestamp values
         let query = BloomFilterQuery::IsIn(vec![
@@ -2070,7 +2070,7 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..100); // Should match both zones
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
     }
 
     #[tokio::test]
@@ -2121,12 +2121,12 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..25);
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test search for Time64 value that doesn't exist
         let query = BloomFilterQuery::Equals(ScalarValue::Time64Microsecond(Some(999_999_999i64)));
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new()));
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new()));
     }
 
     #[tokio::test]
@@ -2172,12 +2172,12 @@ mod tests {
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(500..750); // Should match the zone containing 500
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
 
         // Test IsNull query
         let query = BloomFilterQuery::IsNull();
         let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
-        assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); // No nulls in the data
+        assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // No nulls in the data
 
         // Test IsIn query
         let query = BloomFilterQuery::IsIn(vec![
@@ -2188,6 +2188,86 @@ mod tests {
         let mut expected = RowIdTreeMap::new();
         expected.insert_range(0..250); // Zone containing 100
         expected.insert_range(500..750); // Zone containing 600
-        assert_eq!(result, SearchResult::AtMost(expected));
+        assert_eq!(result, SearchResult::at_most(expected));
+    }
+
+    #[tokio::test]
+    async fn test_bloomfilter_null_handling_in_queries() {
+        // Test that bloomfilter index correctly returns null_list for queries
+        let tmpdir = TempObjDir::default();
+        let store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            tmpdir.clone(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        // Create test data: [0, 5, null]
+        let batch = record_batch!(
+            (VALUE_COLUMN_NAME, Int64, [Some(0), Some(5), None]),
+            (ROW_ADDR, UInt64, [0, 1, 2])
+        )
+        .unwrap();
+        let schema = batch.schema();
+        let stream = stream::once(async move { Ok(batch) });
+        let stream
= Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + // Train and write the bloomfilter index + BloomFilterIndexPlugin::train_bloomfilter_index(stream, store.as_ref(), None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = BloomFilterIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - bloomfilter should return at_most with all rows + // Like ZoneMap, BloomFilter returns AtMost (superset) and includes nulls + let query = BloomFilterQuery::Equals(ScalarValue::Int64(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_ids) => { + // Bloomfilter returns all rows in the zone including nulls + let all_rows: Vec<u64> = row_ids + .selected_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows (including nulls) since BloomFilter is inexact" + ); + + // For AtMost results, nulls are included in the superset + } + _ => panic!("Expected AtMost search result from bloomfilter"), + } + + // Test 2: IsIn query - should also return all rows + let query = BloomFilterQuery::IsIn(vec![ + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(10)), + ]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_ids) => { + let all_rows: Vec<u64> = row_ids + .selected_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows in zone as possible matches" + ); + } + _ => panic!("Expected AtMost search result from bloomfilter"), + } } } diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 8a4ea14c99..3db6250baf 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -44,7 +44,7 @@ use lance_core::{ cache::{CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ - mask::RowIdTreeMap, + mask::NullableRowIdSet, tokio::get_num_compute_intensive_cpus, tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}, }, @@ -832,7 +832,7 @@ impl BTreeIndex { page_number: u32, index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, - ) -> Result<RowIdTreeMap> { + ) -> Result<NullableRowIdSet> { let subindex = self.lookup_page(page_number, index_reader, metrics).await?; // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the // values that might be in the page. E.g.
if we are searching for X IN [5, 3, 7] and five is in pages @@ -1172,13 +1172,19 @@ impl ScalarIndex for BTreeIndex { }) .collect::<Vec<_>>(); debug!("Searching {} btree pages", page_tasks.len()); - let row_ids = stream::iter(page_tasks) + + // Collect both matching row IDs and null row IDs from all pages + let results: Vec<NullableRowIdSet> = stream::iter(page_tasks) // I/O and compute mixed here but important case is index in cache so // use compute intensive thread count .buffered(get_num_compute_intensive_cpus()) - .try_collect::<RowIdTreeMap>() + .try_collect() .await?; - Ok(SearchResult::Exact(row_ids)) + + // Merge matching and null row IDs from all pages + let selection = NullableRowIdSet::union_all(&results); + + Ok(SearchResult::Exact(selection)) } fn can_remap(&self) -> bool { @@ -1999,7 +2005,7 @@ mod tests { use std::{collections::HashMap, sync::Arc}; use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; - use arrow_array::FixedSizeListArray; + use arrow_array::{record_batch, FixedSizeListArray}; use arrow_schema::DataType; use datafusion::{ execution::{SendableRecordBatchStream, TaskContext}, @@ -2008,12 +2014,14 @@ mod tests { use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use deepsize::DeepSizeOf; + use futures::stream; use futures::TryStreamExt; use lance_core::utils::tempfile::TempObjDir; use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap}; use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; use lance_datagen::{array, gen_batch, ArrayGeneratorExt, BatchCount, RowCount}; use lance_io::object_store::ObjectStore; + use object_store::path::Path; use crate::metrics::LocalMetricsCollector; use crate::{ @@ -2161,7 +2169,7 @@ mod tests { let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert_eq!( result, - SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + SearchResult::exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) ); } } @@ -2868,4 +2876,114 @@ mod tests { // This test mainly verifies that the function doesn't panic and handles edge cases super::cleanup_partition_files(&test_store, &lookup_files, &page_files).await; } + + #[tokio::test] + async fn test_btree_null_handling_in_queries() { + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::memory()), + Path::default(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [null, 0, 5] at row IDs [0, 1, 2] + // BTree expects sorted data with nulls first (or filtered out) + let batch = record_batch!( + ("value", Int32, [None, Some(0), Some(5)]), + ("_rowid", UInt64, [0, 1, 2]) + ) + .unwrap(); + let stream = stream::once(futures::future::ok(batch.clone())); + let stream = Box::pin(RecordBatchStreamAdapter::new(batch.schema(), stream)); + + // Train the btree index with FlatIndexMetadata as sub-index + let sub_index_trainer = super::FlatIndexMetadata::new(DataType::Int32); + super::train_btree_index(stream, &sub_index_trainer, store.as_ref(), 256, None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = super::BTreeIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - should return allow=[2], null=[0] + let query = SargableQuery::Equals(ScalarValue::Int32(Some(5))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec<u64> = row_ids + .true_rows() + .row_ids() + .unwrap() +
.map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![2], "Should find row 2 where value == 5"); + + // Check that null_row_ids contains row 0 + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => panic!("Expected Exact search result"), + } + + // Test 2: Range query [0, 3] - should return allow=[1], null=[0] + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int32(Some(0))), + std::ops::Bound::Included(ScalarValue::Int32(Some(3))), + ); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec<u64> = row_ids + .true_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!(actual_rows, vec![1], "Should find row 1 where value == 0"); + + // Should report row 0 as null + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => panic!("Expected Exact search result"), + } + + // Test 3: IsIn query [0, 5] - should return allow=[1, 2], null=[0] + let query = SargableQuery::IsIn(vec![ + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(5)), + ]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let mut actual_rows: Vec<u64> = row_ids + .true_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + actual_rows.sort(); + assert_eq!( + actual_rows, + vec![1, 2], + "Should find rows 1 and 2 where value in [0, 5]" + ); + + // Should report row 0 as null + let null_row_ids = row_ids.null_rows(); + assert!(!null_row_ids.is_empty(), "null_row_ids should be non-empty"); + let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect(); + assert_eq!(null_rows, vec![0], "Should report row 0 as null"); + } + _ => panic!("Expected Exact search result"), + } + } } diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 2e867bc9de..13f6585d2d 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -16,13 +16,16 @@ use datafusion_expr::{ expr::{InList, ScalarFunction}, Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, }; +use tokio::try_join; use super::{ AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, TokenQuery, }; -use futures::join; -use lance_core::{utils::mask::RowIdMask, Error, Result}; +use lance_core::{ + utils::mask::{NullableRowIdMask, RowIdMask}, + Error, Result, +}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; use roaring::RoaringBitmap; use snafu::location; @@ -902,6 +905,81 @@ pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock = LazyLock::new(|| { ])) }); +#[derive(Debug)] +enum NullableIndexExprResult { + Exact(NullableRowIdMask), + AtMost(NullableRowIdMask), + AtLeast(NullableRowIdMask), +} + +impl From<SearchResult> for NullableIndexExprResult { + fn from(result: SearchResult) -> Self { + match result { + SearchResult::Exact(mask) => Self::Exact(NullableRowIdMask::AllowList(mask)), + SearchResult::AtMost(mask) =>
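+ // (as with Exact above, only the payload is rewrapped as an allow-list; the AtMost/AtLeast certainty itself is preserved)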
Self::AtMost(NullableRowIdMask::AllowList(mask)), + SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowIdMask::AllowList(mask)), + } + } +} + +impl std::ops::BitAnd for NullableIndexExprResult { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs & rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(lhs), Self::Exact(rhs)) => { + Self::AtMost(lhs & rhs) + } + (Self::Exact(exact), Self::AtLeast(_)) | (Self::AtLeast(_), Self::Exact(exact)) => { + // We could do better here, elements in both lhs and rhs are known + // to be true and don't require a recheck. We only need to recheck + // elements in lhs that are not in rhs + Self::AtMost(exact) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs & rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs & rhs), + (Self::AtMost(most), Self::AtLeast(_)) | (Self::AtLeast(_), Self::AtMost(most)) => { + Self::AtMost(most) + } + } + } +} + +impl std::ops::BitOr for NullableIndexExprResult { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs | rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(rhs), Self::Exact(lhs)) => { + // We could do better here, elements in lhs are known to be true + // and don't require a recheck. We only need to recheck elements + // in rhs that are not in lhs + Self::AtMost(lhs | rhs) + } + (Self::Exact(lhs), Self::AtLeast(rhs)) | (Self::AtLeast(rhs), Self::Exact(lhs)) => { + Self::AtLeast(lhs | rhs) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs | rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs | rhs), + (Self::AtMost(_), Self::AtLeast(least)) | (Self::AtLeast(least), Self::AtMost(_)) => { + Self::AtLeast(least) + } + } + } +} + +impl NullableIndexExprResult { + pub fn drop_nulls(self) -> IndexExprResult { + match self { + Self::Exact(mask) => IndexExprResult::Exact(mask.drop_nulls()), + Self::AtMost(mask) => IndexExprResult::AtMost(mask.drop_nulls()), + Self::AtLeast(mask) => IndexExprResult::AtLeast(mask.drop_nulls()), + } + } +} + #[derive(Debug)] pub enum IndexExprResult { // The answer is exactly the rows in the allow list minus the rows in the block list @@ -981,117 +1059,59 @@ impl ScalarIndexExpr { /// TODO: We could potentially try and be smarter about reusing loaded indices for /// any situations where the session cache has been disabled. #[async_recursion] - #[instrument(level = "debug", skip_all)] - pub async fn evaluate( + async fn evaluate_impl( &self, index_loader: &dyn ScalarIndexLoader, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { match self { Self::Not(inner) => { - let result = inner.evaluate(index_loader, metrics).await?; - match result { - IndexExprResult::Exact(mask) => Ok(IndexExprResult::Exact(!mask)), - IndexExprResult::AtMost(mask) => Ok(IndexExprResult::AtLeast(!mask)), - IndexExprResult::AtLeast(mask) => Ok(IndexExprResult::AtMost(!mask)), - } - } - Self::And(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader, metrics); - let rhs_result = rhs.evaluate(index_loader, metrics); - let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - match (lhs_result?, rhs_result?) 
{ - (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::Exact(lhs & rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) - | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::AtMost(lhs & rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(_)) => { - // We could do better here, elements in both lhs and rhs are known - // to be true and don't require a recheck. We only need to recheck - // elements in lhs that are not in rhs - Ok(IndexExprResult::AtMost(lhs)) - } - (IndexExprResult::AtLeast(_), IndexExprResult::Exact(rhs)) => { - // We could do better here (see above) - Ok(IndexExprResult::AtMost(rhs)) - } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(lhs & rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs & rhs)) + let result = inner.evaluate_impl(index_loader, metrics).await?; + // Flip certainty: NOT(AtMost) → AtLeast, NOT(AtLeast) → AtMost + Ok(match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => { + NullableIndexExprResult::AtLeast(!mask) } - (IndexExprResult::AtLeast(_), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(rhs)) + NullableIndexExprResult::AtLeast(mask) => { + NullableIndexExprResult::AtMost(!mask) } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtLeast(_)) => { - Ok(IndexExprResult::AtMost(lhs)) - } - } + }) + } + Self::And(lhs, rhs) => { + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result & rhs_result) } Self::Or(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader, metrics); - let rhs_result = rhs.evaluate(index_loader, metrics); - let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - match (lhs_result?, rhs_result?) { - (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::Exact(lhs | rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) - | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { - // We could do better here. Elements in the exact side don't need - // re-check. 
We only need to recheck elements exclusively in the - // at-most side - Ok(IndexExprResult::AtMost(lhs | rhs)) - } - (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::Exact(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { - Ok(IndexExprResult::AtMost(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(lhs | rhs)) - } - (IndexExprResult::AtLeast(lhs), IndexExprResult::AtMost(_)) => { - Ok(IndexExprResult::AtLeast(lhs)) - } - (IndexExprResult::AtMost(_), IndexExprResult::AtLeast(rhs)) => { - Ok(IndexExprResult::AtLeast(rhs)) - } - } + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result | rhs_result) } Self::Query(search) => { let index = index_loader .load_index(&search.column, &search.index_name, metrics) .await?; let search_result = index.search(search.query.as_ref(), metrics).await?; - match search_result { - SearchResult::Exact(matching_row_ids) => { - Ok(IndexExprResult::Exact(RowIdMask { - block_list: None, - allow_list: Some(matching_row_ids), - })) - } - SearchResult::AtMost(row_ids) => Ok(IndexExprResult::AtMost(RowIdMask { - block_list: None, - allow_list: Some(row_ids), - })), - SearchResult::AtLeast(row_ids) => Ok(IndexExprResult::AtLeast(RowIdMask { - block_list: None, - allow_list: Some(row_ids), - })), - } + Ok(search_result.into()) } } } + #[instrument(level = "debug", skip_all)] + pub async fn evaluate( + &self, + index_loader: &dyn ScalarIndexLoader, + metrics: &dyn MetricsCollector, + ) -> Result { + Ok(self + .evaluate_impl(index_loader, metrics) + .await? 
+ .drop_nulls()) + } + pub fn to_expr(&self) -> Expr { match self { Self::Not(inner) => Expr::Not(inner.to_expr().into()), @@ -2175,4 +2195,125 @@ mod tests { check_no_index(&index_info, "aisle BETWEEN 5 AND NULL"); check_no_index(&index_info, "aisle BETWEEN NULL AND 10"); } + + #[tokio::test] + async fn test_not_flips_certainty() { + use lance_core::utils::mask::{NullableRowIdSet, RowIdTreeMap}; + + // Test that NOT flips certainty for inexact index results + // This tests the implementation in evaluate_impl for Self::Not + + // Helper function that mimics the NOT logic we just fixed + fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult { + match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask), + NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask), + } + } + + // AtMost: superset of matches (e.g., bloom filter says "might be in [1,2]") + let at_most = NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList( + NullableRowIdSet::new(RowIdTreeMap::from_iter(&[1, 2]), RowIdTreeMap::new()), + )); + // NOT(AtMost) should be AtLeast (definitely NOT in [1,2], might be elsewhere) + assert!(matches!( + apply_not(at_most), + NullableIndexExprResult::AtLeast(_) + )); + + // AtLeast: subset of matches (e.g., definitely in [1,2], might be more) + let at_least = NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList( + NullableRowIdSet::new(RowIdTreeMap::from_iter(&[1, 2]), RowIdTreeMap::new()), + )); + // NOT(AtLeast) should be AtMost (might NOT be in [1,2], definitely elsewhere) + assert!(matches!( + apply_not(at_least), + NullableIndexExprResult::AtMost(_) + )); + + // Exact should stay Exact + let exact = NullableIndexExprResult::Exact(NullableRowIdMask::AllowList( + NullableRowIdSet::new(RowIdTreeMap::from_iter(&[1, 2]), RowIdTreeMap::new()), + )); + assert!(matches!( + apply_not(exact), + NullableIndexExprResult::Exact(_) + )); + } + + #[tokio::test] + async fn test_and_or_preserve_certainty() { + use lance_core::utils::mask::{NullableRowIdSet, RowIdTreeMap}; + + // Test that AND/OR correctly propagate certainty + let make_at_most = || { + NullableIndexExprResult::AtMost(NullableRowIdMask::AllowList(NullableRowIdSet::new( + RowIdTreeMap::from_iter(&[1, 2, 3]), + RowIdTreeMap::new(), + ))) + }; + + let make_at_least = || { + NullableIndexExprResult::AtLeast(NullableRowIdMask::AllowList(NullableRowIdSet::new( + RowIdTreeMap::from_iter(&[2, 3, 4]), + RowIdTreeMap::new(), + ))) + }; + + let make_exact = || { + NullableIndexExprResult::Exact(NullableRowIdMask::AllowList(NullableRowIdSet::new( + RowIdTreeMap::from_iter(&[1, 2]), + RowIdTreeMap::new(), + ))) + }; + + // AtMost & AtMost → AtMost + assert!(matches!( + make_at_most() & make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast & AtLeast → AtLeast + assert!(matches!( + make_at_least() & make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost & AtLeast → AtMost (superset remains superset) + assert!(matches!( + make_at_most() & make_at_least(), + NullableIndexExprResult::AtMost(_) + )); + + // AtMost | AtMost → AtMost + assert!(matches!( + make_at_most() | make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast | AtLeast → AtLeast + assert!(matches!( + make_at_least() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost | AtLeast → AtLeast (subset coverage guaranteed) + 
assert!(matches!( + make_at_most() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // Exact & AtMost → AtMost + assert!(matches!( + make_exact() & make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // Exact | AtLeast → AtLeast + assert!(matches!( + make_exact() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + } } diff --git a/rust/lance-index/src/scalar/flat.rs b/rust/lance-index/src/scalar/flat.rs index 99fb263921..4f4dd66afe 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -15,7 +15,7 @@ use datafusion_physical_expr::expressions::{in_list, lit, Column}; use deepsize::DeepSizeOf; use lance_core::error::LanceOptionExt; use lance_core::utils::address::RowAddress; -use lance_core::utils::mask::RowIdTreeMap; +use lance_core::utils::mask::{NullableRowIdSet, RowIdTreeMap}; use lance_core::{Error, Result, ROW_ID}; use roaring::RoaringBitmap; use snafu::location; @@ -299,14 +299,37 @@ impl ScalarIndex for FlatIndex { let valid_values = arrow::compute::is_not_null(self.values())?; predicate = arrow::compute::and(&valid_values, &predicate)?; } + + // Track null row IDs for Kleene logic + // When querying FOR nulls (IS NULL or Equals(null)), don't track them as "null results" + // because they are the TRUE result of the query + let null_row_ids = if self.has_nulls + && !matches!(query, SargableQuery::IsNull()) + && !matches!(query, SargableQuery::Equals(val) if val.is_null()) + { + let null_mask = arrow::compute::is_null(self.values())?; + let null_ids = arrow_select::filter::filter(self.ids(), &null_mask)?; + let null_ids = null_ids + .as_any() + .downcast_ref::<UInt64Array>() + .expect("Result of arrow_select::filter::filter did not match input type"); + if null_ids.is_empty() { + None + } else { + Some(RowIdTreeMap::from_iter(null_ids.values())) + } + } else { + None + }; + + let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?; let matching_ids = matching_ids .as_any() .downcast_ref::<UInt64Array>() .expect("Result of arrow_select::filter::filter did not match input type"); - Ok(SearchResult::Exact(RowIdTreeMap::from_iter( - matching_ids.values(), - ))) + let selected = RowIdTreeMap::from_iter(matching_ids.values()); + let selection = NullableRowIdSet::new(selected, null_row_ids.unwrap_or_default()); + Ok(SearchResult::Exact(selection)) } fn can_remap(&self) -> bool { @@ -372,7 +395,7 @@ mod tests { let SearchResult::Exact(actual_row_ids) = actual else { panic!
{"Expected exact search result"} }; - let expected = RowIdTreeMap::from_iter(expected); + let expected = NullableRowIdSet::new(RowIdTreeMap::from_iter(expected), Default::default()); assert_eq!(actual_row_ids, expected); } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 7a5afc4bf4..48f1c64fe3 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -533,7 +533,7 @@ impl ScalarIndex for InvertedIndex { .downcast_ref::() .unwrap(); let row_ids = row_ids.iter().flatten().collect_vec(); - Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids))) + Ok(SearchResult::at_most(RowIdTreeMap::from_iter(row_ids))) } } } diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index b22a12f8e4..037f107a50 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -13,7 +13,8 @@ use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{stream::BoxStream, StreamExt, TryStream, TryStreamExt}; use lance_core::cache::LanceCache; -use lance_core::{utils::mask::RowIdTreeMap, Error, Result}; +use lance_core::utils::mask::NullableRowIdSet; +use lance_core::{Error, Result}; use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; @@ -41,7 +42,7 @@ trait LabelListSubIndex: ScalarIndex + DeepSizeOf { &self, query: &dyn AnyQuery, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { let result = self.search(query, metrics).await?; match result { SearchResult::Exact(row_ids) => Ok(row_ids), @@ -118,7 +119,7 @@ impl LabelListIndex { &'a self, values: &'a Vec, metrics: &'a dyn MetricsCollector, - ) -> BoxStream<'a, Result> { + ) -> BoxStream<'a, Result> { futures::stream::iter(values) .then(move |value| { let value_query = SargableQuery::Equals(value.clone()); @@ -129,24 +130,24 @@ impl LabelListIndex { async fn set_union<'a>( &'a self, - mut sets: impl TryStream + 'a + Unpin, + mut sets: impl TryStream + 'a + Unpin, single_set: bool, - ) -> Result { + ) -> Result { let mut union_bitmap = sets.try_next().await?.unwrap(); if single_set { return Ok(union_bitmap); } while let Some(next) = sets.try_next().await? 
{ - union_bitmap |= next; + union_bitmap |= &next; } Ok(union_bitmap) } async fn set_intersection<'a>( - mut sets: impl TryStream<Ok = RowIdTreeMap, Error = Error> + 'a + Unpin, + mut sets: impl TryStream<Ok = NullableRowIdSet, Error = Error> + 'a + Unpin, single_set: bool, - ) -> Result<RowIdTreeMap> { + ) -> Result<NullableRowIdSet> { let mut intersect_bitmap = sets.try_next().await?.unwrap(); if single_set { return Ok(intersect_bitmap); diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 2d6703bbf0..5108c8d8a6 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -313,7 +313,7 @@ pub mod tests { bitmap::BitmapIndex, btree::{train_btree_index, DEFAULT_BTREE_BATCH_SIZE}, flat::FlatIndexMetadata, - LabelListQuery, SargableQuery, ScalarIndex, + LabelListQuery, SargableQuery, ScalarIndex, SearchResult, }; use super::*; @@ -321,7 +321,7 @@ pub mod tests { use arrow_array::{ cast::AsArray, types::{Int32Type, UInt64Type}, - RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, + ListArray, RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, }; use arrow_schema::Schema as ArrowSchema; use arrow_schema::{DataType, Field, TimeUnit}; @@ -402,7 +402,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); @@ -418,7 +418,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(0), row_ids.len()); @@ -434,7 +434,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(100), row_ids.len()); } @@ -494,7 +494,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); @@ -508,7 +508,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(500_000)); @@ -518,7 +518,7 @@ pub mod tests { let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert!(results.is_exact()); let expected_arr = RowIdTreeMap::from_iter(expected); - assert_eq!(results.row_ids(), &expected_arr); + assert_eq!(&results.row_ids().true_rows(), &expected_arr); } #[tokio::test] @@ -823,7 +823,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); // The random data may have had duplicates so there might be more than 1 result // but even for boolean we shouldn't match the entire thing @@ -886,7 +886,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert!(row_ids.is_empty()); @@ -895,7 +895,7 @@ pub mod tests { .await .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(row_ids.len(), Some(4096)); } @@ -962,7 +962,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(2)); @@ -975,7 +975,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let
row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(3), row_ids.len()); assert!(row_ids.contains(1)); assert!(row_ids.contains(3)); @@ -1004,7 +1004,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); @@ -1020,7 +1020,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert!(row_ids.is_empty()); let result = index @@ -1035,7 +1035,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(100), row_ids.len()); } @@ -1043,7 +1043,7 @@ pub mod tests { let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert!(results.is_exact()); let expected_arr = RowIdTreeMap::from_iter(expected); - assert_eq!(results.row_ids(), &expected_arr); + assert_eq!(&results.row_ids().true_rows(), &expected_arr); } #[tokio::test] @@ -1307,7 +1307,7 @@ pub mod tests { .unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(5000)); } @@ -1357,7 +1357,7 @@ pub mod tests { .await .unwrap() .row_ids() - .contains(65)); + .selected(65)); // Deleted assert!(remapped_index .search( @@ -1377,7 +1377,7 @@ pub mod tests { .await .unwrap() .row_ids() - .contains(3)); + .selected(3)); } async fn train_tag( @@ -1442,7 +1442,7 @@ pub mod tests { .unwrap(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); assert!(result.is_exact()); - let row_ids = result.row_ids(); + let row_ids = result.row_ids().true_rows(); let row_ids_set = row_ids .row_ids() @@ -1506,4 +1506,84 @@ pub mod tests { ) .await; } + + #[tokio::test] + async fn test_label_list_null_handling() { + let tempdir = TempDir::default(); + let index_store = test_store(&tempdir); + + // Create test data with null items within lists: + // Row 0: [1, 2] - no nulls + // Row 1: [3, null] - has a null item + // Row 2: [4] - no nulls + let list_array = ListArray::from_iter_primitive::<UInt8Type, _, _>(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), None]), + Some(vec![Some(4)]), + ]); + let row_ids = UInt64Array::from_iter_values(0..3); + // Create schema with nullable list items to match the ListArray + let schema = Arc::new(Schema::new(vec![ + Field::new( + VALUE_COLUMN_NAME, + DataType::List(Arc::new(Field::new("item", DataType::UInt8, true))), + true, + ), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(list_array), Arc::new(row_ids)], + ) + .unwrap(); + + let batch_reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + train_tag(&index_store, batch_reader).await; + + let index = LabelListIndexPlugin + .load_index( + index_store, + &default_details::(), + None, + &LanceCache::no_cache(), + ) + .await + .unwrap(); + + // Test: Search for lists containing value 1 + // Row 0: [1, 2] - contains 1 → TRUE + // Row 1: [3, null] - has null item, unknown if it matches → NULL + // Row 2: [4] - doesn't contain 1 → FALSE + let query = LabelListQuery::HasAnyLabel(vec![ScalarValue::UInt8(Some(1))]); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::Exact(row_ids) => { + let actual_rows: Vec<u64> = row_ids +
.true_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + actual_rows, + vec![0], + "Should find row 0 where list contains 1" + ); + + let null_row_ids = row_ids.null_rows(); + assert!( + !null_row_ids.is_empty(), + "null_row_ids should not be empty - row 1 has null item" + ); + let null_rows: Vec<u64> = null_row_ids.row_ids().unwrap().map(u64::from).collect(); + assert_eq!( + null_rows, + vec![1], + "Should report row 1 as null because it contains a null item" + ); + } + _ => panic!("Expected Exact search result"), + } + } } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index 00a2f7da5d..872c0bad75 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -451,7 +451,7 @@ impl ScalarIndex for NGramIndex { TextQuery::StringContains(substr) => { if substr.len() < NGRAM_N { // We know nothing on short searches, need to recheck all - return Ok(SearchResult::AtLeast(RowIdTreeMap::new())); + return Ok(SearchResult::at_least(RowIdTreeMap::new())); } let mut row_offsets = Vec::with_capacity(substr.len() * 3); @@ -466,7 +466,7 @@ impl ScalarIndex for NGramIndex { }); // At least one token was missing, so we know there are zero results if missing { - return Ok(SearchResult::Exact(RowIdTreeMap::new())); + return Ok(SearchResult::exact(RowIdTreeMap::new())); } let posting_lists = futures::stream::iter( row_offsets @@ -479,7 +479,7 @@ metrics.record_comparisons(posting_lists.len()); let list_refs = posting_lists.iter().map(|list| list.as_ref()); let row_ids = NGramPostingList::intersect(list_refs); - Ok(SearchResult::AtMost(RowIdTreeMap::from(row_ids))) + Ok(SearchResult::at_most(RowIdTreeMap::from(row_ids))) } } } @@ -1483,7 +1483,7 @@ mod tests { .await .unwrap(); - let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([0, 2, 3])); + let expected = SearchResult::at_most(RowIdTreeMap::from_iter([0, 2, 3])); assert_eq!(expected, res); @@ -1495,7 +1495,7 @@ ) .await .unwrap(); - let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + let expected = SearchResult::at_most(RowIdTreeMap::from_iter([8])); assert_eq!(expected, res); // No matches @@ -1506,7 +1506,7 @@ ) .await .unwrap(); - let expected = SearchResult::Exact(RowIdTreeMap::new()); + let expected = SearchResult::exact(RowIdTreeMap::new()); assert_eq!(expected, res); // False positive @@ -1517,7 +1517,7 @@ ) .await .unwrap(); - let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + let expected = SearchResult::at_most(RowIdTreeMap::from_iter([8])); assert_eq!(expected, res); // Too short, don't know anything @@ -1528,7 +1528,7 @@ ) .await .unwrap(); - let expected = SearchResult::AtLeast(RowIdTreeMap::new()); + let expected = SearchResult::at_least(RowIdTreeMap::new()); assert_eq!(expected, res); // One short string but we still get at least one trigram, this is ok @@ -1539,7 +1539,7 @@ ) .await .unwrap(); - let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + let expected = SearchResult::at_most(RowIdTreeMap::from_iter([8])); assert_eq!(expected, res); } @@ -1578,7 +1578,7 @@ mod tests { ) .await .unwrap(); - let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([0, 4])); + let expected = SearchResult::at_most(RowIdTreeMap::from_iter([0, 4])); assert_eq!(expected, res); let null_posting_list = get_null_posting_list(&index).await; diff --git a/rust/lance-index/src/scalar/zonemap.rs
b/rust/lance-index/src/scalar/zonemap.rs index 7b6e607831..095b8b449c 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -552,7 +552,7 @@ impl ScalarIndex for ZoneMapIndex { } } - Ok(SearchResult::AtMost(row_id_tree_map)) + Ok(SearchResult::at_most(row_id_tree_map)) } fn can_remap(&self) -> bool { @@ -1028,12 +1028,13 @@ mod tests { use crate::scalar::zonemap::{ZoneMapIndexPlugin, ZoneMapStatistics}; use arrow::datatypes::Float32Type; - use arrow_array::{Array, RecordBatch, UInt64Array}; + use arrow_array::{record_batch, Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::ScalarValue; use futures::{stream, StreamExt, TryStreamExt}; + use lance_core::utils::mask::NullableRowIdSet; use lance_core::utils::tempfile::TempObjDir; use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap, ROW_ADDR}; use lance_datafusion::datagen::DatafusionDatagenExt; @@ -1116,7 +1117,7 @@ mod tests { // Equals query: null (should match nothing, as there are no nulls) let query = SargableQuery::Equals(ScalarValue::Int32(None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); } #[tokio::test] @@ -1172,7 +1173,7 @@ mod tests { let end = start + 5000; expected.insert_range(start..end); } - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test update - add new data with Float32 values (matching the original data type) let new_data = @@ -1230,7 +1231,7 @@ mod tests { let end = start + 5000; expected.insert_range(start..end); } - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that should be in the new zone let query = SargableQuery::Equals(ScalarValue::Float32(Some(2.5))); // Value 2500/1000 = 2.5 @@ -1244,7 +1245,90 @@ mod tests { let start = 10u64 << 32; let end = start + 5000; expected.insert_range(start..end); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); + } + + #[tokio::test] + async fn test_zonemap_null_handling_in_queries() { + // Test that zonemap index correctly returns null_list for queries + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // Create test data: [0, 5, null] + let batch = record_batch!( + (VALUE_COLUMN_NAME, Int64, [Some(0), Some(5), None]), + (ROW_ADDR, UInt64, [0, 1, 2]) + ) + .unwrap(); + let schema = batch.schema(); + let stream = stream::once(async move { Ok(batch) }); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + // Train and write the zonemap index + ZoneMapIndexPlugin::train_zonemap_index(stream, store.as_ref(), None) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = ZoneMapIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + + // Test 1: Search for value 5 - zonemap should return at_most with all rows + // Since ZoneMap returns AtMost (superset), it's correct to include nulls in the result + let query = SargableQuery::Equals(ScalarValue::Int64(Some(5))); + let result = index.search(&query, 
&NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_ids) => { + // Zonemap can't determine exact matches, so it returns all rows in the zone + // This includes nulls because ZoneMap can't prove they don't match + let all_rows: Vec = row_ids + .selected_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows (including nulls) since ZoneMap is inexact" + ); + + // For AtMost results, nulls are included in the superset + // Downstream processing will handle null filtering + } + _ => panic!("Expected AtMost search result from zonemap"), + } + + // Test 2: Range query - should also return all rows as AtMost + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int64(Some(0))), + std::ops::Bound::Included(ScalarValue::Int64(Some(3))), + ); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + + match result { + SearchResult::AtMost(row_ids) => { + // Again, ZoneMap returns superset including nulls + let all_rows: Vec = row_ids + .selected_rows() + .row_ids() + .unwrap() + .map(u64::from) + .collect(); + assert_eq!( + all_rows, + vec![0, 1, 2], + "Should return all rows in zone as possible matches" + ); + } + _ => panic!("Expected AtMost search result from zonemap"), + } } #[tokio::test] @@ -1320,7 +1404,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); // All rows since NaN is in every zone - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a specific finite value that exists in the data let query = SargableQuery::Equals(ScalarValue::Float32(Some(5.0))); @@ -1329,7 +1413,7 @@ mod tests { // Should match only the first zone since 5.0 only exists in rows 0-99 let mut expected = RowIdTreeMap::new(); expected.insert_range(0..100); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test search for a value that doesn't exist let query = SargableQuery::Equals(ScalarValue::Float32(Some(1000.0))); @@ -1339,7 +1423,7 @@ mod tests { // as potential matches for any finite target (false positive, but acceptable for zone maps) let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test range query that should include finite values let query = SargableQuery::Range( @@ -1351,7 +1435,7 @@ mod tests { // Should match the first three zones since they contain values in the range [0, 250] let mut expected = RowIdTreeMap::new(); expected.insert_range(0..300); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test IsIn query with NaN and finite values let query = SargableQuery::IsIn(vec![ @@ -1364,7 +1448,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test range query that excludes all values let query = SargableQuery::Range( @@ -1377,12 +1461,12 @@ mod tests { // as potential matches for any range query (false positive, but acceptable for zone maps) let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - 
assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test IsNull query (should match nothing since there are no null values) let query = SargableQuery::IsNull(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::AtMost(NullableRowIdSet::empty())); // Test range queries with NaN bounds // Range with NaN as start bound (included) @@ -1394,7 +1478,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as end bound (included) let query = SargableQuery::Range( @@ -1405,7 +1489,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as end bound (excluded) let query = SargableQuery::Range( @@ -1416,7 +1500,7 @@ mod tests { // Should match all zones since everything is less than NaN let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Range with NaN as start bound (excluded) let query = SargableQuery::Range( @@ -1425,7 +1509,7 @@ mod tests { ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Should match nothing since nothing is greater than NaN - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::AtMost(NullableRowIdSet::empty())); // Test IsIn query with mixed float types (Float16, Float32, Float64) let query = SargableQuery::IsIn(vec![ @@ -1438,7 +1522,7 @@ mod tests { // Should match all zones since they all contain NaN values let mut expected = RowIdTreeMap::new(); expected.insert_range(0..500); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1560,10 +1644,7 @@ mod tests { Bound::Unbounded, ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(0..=100)) - ); + assert_eq!(result, SearchResult::at_most(0..=100)); // 2. Range query: [0, 50] let query = SargableQuery::Range( @@ -1571,10 +1652,7 @@ mod tests { Bound::Included(ScalarValue::Int32(Some(50))), ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(0..=99)) - ); + assert_eq!(result, SearchResult::at_most(0..=99)); // 3. Range query: [101, 200] (should only match the second zone, which is row 100) let query = SargableQuery::Range( @@ -1583,7 +1661,7 @@ mod tests { ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // Only row 100 is in the second zone, but its value is 100, so this should be empty - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // 4. 
Range query: [100, 100] (should match only the last row) let query = SargableQuery::Range( @@ -1591,37 +1669,27 @@ mod tests { Bound::Included(ScalarValue::Int32(Some(100))), ); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(100..=100)) - ); + assert_eq!(result, SearchResult::at_most(100..=100)); // 5. Equals query: 0 (should match first row) let query = SargableQuery::Equals(ScalarValue::Int32(Some(0))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(0..100)) - ); + assert_eq!(result, SearchResult::at_most(0..=99)); // 6. Equals query: 100 (should match only last row) let query = SargableQuery::Equals(ScalarValue::Int32(Some(100))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(100..=100)) - ); + assert_eq!(result, SearchResult::at_most(100..=100)); // 7. Equals query: 101 (should match nothing) let query = SargableQuery::Equals(ScalarValue::Int32(Some(101))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // 8. IsNull query (no nulls in data, should match nothing) let query = SargableQuery::IsNull(); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); - + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // 9. IsIn query: [0, 100, 101, 50] let query = SargableQuery::IsIn(vec![ ScalarValue::Int32(Some(0)), @@ -1631,10 +1699,7 @@ mod tests { ]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); // 0 and 50 are in the first zone, 100 in the second, 101 is not present - assert_eq!( - result, - SearchResult::AtMost(RowIdTreeMap::from_iter(0..=100)) - ); + assert_eq!(result, SearchResult::at_most(0..=100)); // 10. IsIn query: [101, 102] (should match nothing) let query = SargableQuery::IsIn(vec![ @@ -1642,17 +1707,17 @@ mod tests { ScalarValue::Int32(Some(102)), ]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // 11. IsIn query: [null] (should match nothing, as there are no nulls) let query = SargableQuery::IsIn(vec![ScalarValue::Int32(None)]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // 12. 
Equals query: null (should match nothing, as there are no nulls) let query = SargableQuery::Equals(ScalarValue::Int32(None)); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); } #[tokio::test] @@ -1748,7 +1813,7 @@ mod tests { // Should match row 1000 in fragment 0: row address = (0 << 32) + 1000 = 1000 let mut expected = RowIdTreeMap::new(); expected.insert_range(0..=8191); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Search for a value in the second zone let query = SargableQuery::Equals(ScalarValue::Int64(Some(9000))); @@ -1756,12 +1821,12 @@ mod tests { // Should match row 9000 in fragment 0: row address = (0 << 32) + 9000 = 9000 let mut expected = RowIdTreeMap::new(); expected.insert_range(8192..=16383); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Search for a value not present in any zone let query = SargableQuery::Equals(ScalarValue::Int64(Some(20000))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // Search for a range that spans multiple zones let query = SargableQuery::Range( @@ -1772,7 +1837,7 @@ mod tests { // Should match all rows from 8000 to 16400 (inclusive) let mut expected = RowIdTreeMap::new(); expected.insert_range(8192..=16425); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); } #[tokio::test] @@ -1984,7 +2049,7 @@ mod tests { expected.insert_range(5000..8192); // zone 2 expected.insert_range((1u64 << 32)..((1u64 << 32) + 5000)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test exact match query from zone 2 let query = SargableQuery::Equals(ScalarValue::Int64(Some(8192))); @@ -1992,7 +2057,7 @@ mod tests { // Should include zone 2 since it contains value 8192 let mut expected = RowIdTreeMap::new(); expected.insert_range((1u64 << 32)..((1u64 << 32) + 5000)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test exact match query from zone 4 let query = SargableQuery::Equals(ScalarValue::Int64(Some(16385))); @@ -2000,19 +2065,19 @@ mod tests { // Should include zone 4 since it contains value 16385 let mut expected = RowIdTreeMap::new(); expected.insert_range(2u64 << 32..((2u64 << 32) + 42)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test query that matches nothing let query = SargableQuery::Equals(ScalarValue::Int64(Some(99999))); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); // Test is_in query let query = SargableQuery::IsIn(vec![ScalarValue::Int64(Some(16385))]); let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); let mut expected = RowIdTreeMap::new(); expected.insert_range(2u64 << 32..((2u64 << 32) + 42)); - assert_eq!(result, SearchResult::AtMost(expected)); + assert_eq!(result, SearchResult::at_most(expected)); // Test equals query with null let query = 
SargableQuery::Equals(ScalarValue::Int64(None)); @@ -2020,7 +2085,7 @@ mod tests { let mut expected = RowIdTreeMap::new(); expected.insert_range(0..=16425); // expected = {:?}", expected - assert_eq!(result, SearchResult::AtMost(RowIdTreeMap::new())); + assert_eq!(result, SearchResult::at_most(RowIdTreeMap::new())); } // Each fragment is its own batch diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index 16787aa877..f24816710b 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -118,7 +118,7 @@ async fn warm_indexed_equality_search(index: &dyn ScalarIndex) { let SearchResult::Exact(row_ids) = result else { panic!("Expected exact results") }; - assert_eq!(row_ids.len(), Some(1)); + assert_eq!(row_ids.true_rows().len(), Some(1)); } async fn baseline_inequality_search(fixture: &BenchmarkFixture) { @@ -155,7 +155,7 @@ async fn warm_indexed_inequality_search(index: &dyn ScalarIndex) { }; // 100Mi - 50M = 54,857,600 - assert_eq!(row_ids.len(), Some(54857600)); + assert_eq!(row_ids.true_rows().len(), Some(54857600)); } async fn warm_indexed_isin_search(index: &dyn ScalarIndex) { @@ -176,7 +176,7 @@ async fn warm_indexed_isin_search(index: &dyn ScalarIndex) { }; // Only 3 because 150M is not in dataset - assert_eq!(row_ids.len(), Some(3)); + assert_eq!(row_ids.true_rows().len(), Some(3)); } fn bench_baseline(c: &mut Criterion) { diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs index 9c5c2ecc44..269fdc83bb 100644 --- a/rust/lance/src/index/prefilter.rs +++ b/rust/lance/src/index/prefilter.rs @@ -351,7 +351,7 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1)); // There was just one row deleted. + assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1)); // There was just one row deleted. // If there are deletions and missing fragments, we should get a mask let mask = DatasetPreFilter::create_deletion_mask( @@ -362,7 +362,7 @@ mod test { let mask = mask.unwrap().await.unwrap(); let mut expected = RowIdTreeMap::from_iter(vec![(2 << 32) + 2]); expected.insert_fragment(1); - assert_eq!(&mask.block_list, &Some(expected)); + assert_eq!(mask.block_list(), Some(&expected)); // If we don't pass the missing fragment id, we should get a smaller mask. let mask = DatasetPreFilter::create_deletion_mask( @@ -371,7 +371,7 @@ mod test { ); assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); - assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1)); + assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1)); // If there are only missing fragments, we should still get a mask let mask = DatasetPreFilter::create_deletion_mask( @@ -383,7 +383,7 @@ mod test { let mut expected = RowIdTreeMap::new(); expected.insert_fragment(1); expected.insert_fragment(2); - assert_eq!(&mask.block_list, &Some(expected)); + assert_eq!(mask.block_list(), Some(&expected)); } #[tokio::test] @@ -406,7 +406,7 @@ mod test { assert!(mask.is_some()); let mask = mask.unwrap().await.unwrap(); let expected = RowIdTreeMap::from_iter(0..8); - assert_eq!(mask.allow_list, Some(expected)); // There was just one row deleted. + assert_eq!(mask.allow_list(), Some(&expected)); // There was just one row deleted. 
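// allow_list() and block_list() are Option accessors (Some only for the matching mask variant), so and_then(len) checks the variant and the row count in one assertion.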
diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs
index 9c5c2ecc44..269fdc83bb 100644
--- a/rust/lance/src/index/prefilter.rs
+++ b/rust/lance/src/index/prefilter.rs
@@ -351,7 +351,7 @@ mod test {
         );
         assert!(mask.is_some());
         let mask = mask.unwrap().await.unwrap();
-        assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1)); // There was just one row deleted.
+        assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1)); // There was just one row deleted.
 
         // If there are deletions and missing fragments, we should get a mask
         let mask = DatasetPreFilter::create_deletion_mask(
@@ -362,7 +362,7 @@ mod test {
         let mask = mask.unwrap().await.unwrap();
         let mut expected = RowIdTreeMap::from_iter(vec![(2 << 32) + 2]);
         expected.insert_fragment(1);
-        assert_eq!(&mask.block_list, &Some(expected));
+        assert_eq!(mask.block_list(), Some(&expected));
 
         // If we don't pass the missing fragment id, we should get a smaller mask.
         let mask = DatasetPreFilter::create_deletion_mask(
@@ -371,7 +371,7 @@ mod test {
         );
         assert!(mask.is_some());
         let mask = mask.unwrap().await.unwrap();
-        assert_eq!(mask.block_list.as_ref().and_then(|x| x.len()), Some(1));
+        assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1));
 
         // If there are only missing fragments, we should still get a mask
         let mask = DatasetPreFilter::create_deletion_mask(
@@ -383,7 +383,7 @@ mod test {
         let mut expected = RowIdTreeMap::new();
         expected.insert_fragment(1);
         expected.insert_fragment(2);
-        assert_eq!(&mask.block_list, &Some(expected));
+        assert_eq!(mask.block_list(), Some(&expected));
     }
 
     #[tokio::test]
@@ -406,7 +406,7 @@ mod test {
         assert!(mask.is_some());
         let mask = mask.unwrap().await.unwrap();
         let expected = RowIdTreeMap::from_iter(0..8);
-        assert_eq!(mask.allow_list, Some(expected)); // There was just one row deleted.
+        assert_eq!(mask.allow_list(), Some(&expected)); // There was just one row deleted.
 
         // If there are deletions and missing fragments, we should get an allow list
         let mask = DatasetPreFilter::create_deletion_mask(
@@ -415,7 +415,7 @@ mod test {
         );
         assert!(mask.is_some());
         let mask = mask.unwrap().await.unwrap();
-        assert_eq!(mask.allow_list.as_ref().and_then(|x| x.len()), Some(5)); // There were five rows left over;
+        assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(5)); // There were five rows left over;
 
         // If there are only missing fragments, we should get an allow list
         let mask = DatasetPreFilter::create_deletion_mask(
@@ -424,6 +424,6 @@ mod test {
         );
         assert!(mask.is_some());
         let mask = mask.unwrap().await.unwrap();
-        assert_eq!(mask.allow_list.as_ref().and_then(|x| x.len()), Some(3)); // There were three rows left over;
+        assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(3)); // There were three rows left over;
     }
 }
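In these prefilter tests, the mask's allow and block lists are now read through accessors returning `Option<&RowIdTreeMap>` instead of public fields. Note `and_then` rather than `map` when counting: `len()` on the tree map is itself an `Option`, presumably `None` once whole fragments with unknown row counts are inserted (the tests above add them via `insert_fragment`). A stand-in sketch of that double-`Option` chaining, with assumed names:

```rust
// Stand-in for RowIdTreeMap: explicit row ids plus wholesale fragments.
struct RowIdTreeMapStub {
    row_ids: Vec<u64>,
    full_fragments: Vec<u32>,
}

impl RowIdTreeMapStub {
    // Assumed contract: a map holding whole fragments cannot report an
    // exact row count, so len() is Option<u64>.
    fn len(&self) -> Option<u64> {
        if self.full_fragments.is_empty() {
            Some(self.row_ids.len() as u64)
        } else {
            None // row count of a whole fragment is unknown here
        }
    }
}

fn main() {
    // One deleted row, as in the first prefilter assertion.
    let deleted = RowIdTreeMapStub {
        row_ids: vec![(2 << 32) + 2],
        full_fragments: vec![],
    };
    let block_list: Option<&RowIdTreeMapStub> = Some(&deleted);
    // Mirrors assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1));
    assert_eq!(block_list.and_then(|x| x.len()), Some(1));

    // Once a whole fragment is blocked, the count is no longer known.
    let with_frag = RowIdTreeMapStub {
        row_ids: vec![],
        full_fragments: vec![1],
    };
    assert_eq!(Some(&with_frag).and_then(|x| x.len()), None);
}
```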
diff --git a/rust/lance/src/io/exec/filtered_read.rs b/rust/lance/src/io/exec/filtered_read.rs
index f97ebcdadb..4db065ec99 100644
--- a/rust/lance/src/io/exec/filtered_read.rs
+++ b/rust/lance/src/io/exec/filtered_read.rs
@@ -81,7 +81,7 @@ impl EvaluatedIndex {
         if batch.num_rows() != 2 {
             return Err(Error::InvalidInput {
                 source: format!(
-                    "Expected a batch with exactly one row but there are {} rows",
+                    "Expected a batch with exactly 2 rows but there are {} rows",
                     batch.num_rows()
                 )
                 .into(),
diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs
index 26c69d1f5c..496bb8ff3d 100644
--- a/rust/lance/src/io/exec/scalar_index.rs
+++ b/rust/lance/src/io/exec/scalar_index.rs
@@ -320,29 +320,16 @@ impl MapIndexExec {
             row_id_mask = row_id_mask & deletion_mask.as_ref().clone();
         }
 
-        if let Some(mut allow_list) = row_id_mask.allow_list {
-            // Flatten the allow list
-            if let Some(block_list) = row_id_mask.block_list {
-                allow_list -= &block_list;
-            }
-
-            let allow_list =
-                allow_list
-                    .row_ids()
-                    .ok_or(datafusion::error::DataFusionError::External(
-                        "IndexedLookupExec: row addresses didn't have an iterable allow list"
-                            .into(),
-                    ))?;
-            let allow_list: UInt64Array = allow_list.map(u64::from).collect();
-            Ok(RecordBatch::try_new(
-                INDEX_LOOKUP_SCHEMA.clone(),
-                vec![Arc::new(allow_list)],
-            )?)
-        } else {
-            Err(datafusion::error::DataFusionError::Internal(
-                "IndexedLookupExec: row addresses didn't have an allow list".to_string(),
-            ))
-        }
+        let row_id_iter = row_id_mask
+            .iter_ids()
+            .ok_or(datafusion::error::DataFusionError::Internal(
+                "IndexedLookupExec: Cannot iterate over row addresses (BlockList or contains full fragments)".to_string(),
+            ))?;
+        let allow_list: UInt64Array = row_id_iter.map(u64::from).collect();
+        Ok(RecordBatch::try_new(
+            INDEX_LOOKUP_SCHEMA.clone(),
+            vec![Arc::new(allow_list)],
+        )?)
     }
 
     async fn do_execute(
@@ -591,8 +578,8 @@ async fn row_ids_for_mask(
     dataset: &Dataset,
     fragments: &[Fragment],
 ) -> Result<Vec<u64>> {
-    match (mask.allow_list, mask.block_list) {
-        (None, None) => {
+    match mask {
+        RowIdMask::BlockList(block_list) if block_list.is_empty() => {
             // Matches all row ids in the given fragments.
             if dataset.manifest.uses_stable_row_ids() {
                 let sequences = load_row_id_sequences(dataset, fragments)
@@ -610,7 +597,7 @@ async fn row_ids_for_mask(
                 Ok(FragIdIter::new(fragments).collect::<Vec<_>>())
             }
         }
-        (Some(mut allow_list), None) => {
+        RowIdMask::AllowList(mut allow_list) => {
             retain_fragments(&mut allow_list, fragments, dataset).await?;
 
             if let Some(allow_list_iter) = allow_list.row_ids() {
@@ -623,7 +610,7 @@ async fn row_ids_for_mask(
                     .collect())
             }
         }
-        (None, Some(block_list)) => {
+        RowIdMask::BlockList(block_list) => {
             if dataset.manifest.uses_stable_row_ids() {
                 let sequences = load_row_id_sequences(dataset, fragments)
                     .map_ok(|(_frag_id, sequence)| sequence)
@@ -647,29 +634,6 @@ async fn row_ids_for_mask(
                     .collect())
             }
         }
-        (Some(mut allow_list), Some(block_list)) => {
-            // We need to filter out irrelevant fragments as well.
-            retain_fragments(&mut allow_list, fragments, dataset).await?;
-
-            if let Some(allow_list_iter) = allow_list.row_ids() {
-                Ok(allow_list_iter
-                    .filter_map(|addr| {
-                        let row_id = u64::from(addr);
-                        if !block_list.contains(row_id) {
-                            Some(row_id)
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>())
-            } else {
-                // We shouldn't hit this branch if the row ids are stable.
-                debug_assert!(!dataset.manifest.uses_stable_row_ids());
-                Ok(FragIdIter::new(fragments)
-                    .filter(|row_id| !block_list.contains(*row_id) && allow_list.contains(*row_id))
-                    .collect())
-            }
-        }
     }
 }
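The rewritten `MapIndexExec` path replaces the manual allow/block flattening with a single `iter_ids()` call, and `row_ids_for_mask` now matches directly on the two-variant mask. The sketch below captures the `iter_ids` contract that code appears to assume: ids can be enumerated only for an allow list, never for a block list (and, per the error message above, presumably also not when the list contains whole fragments, which is omitted here). Stand-in types only, not the lance-core API.

```rust
use std::collections::BTreeSet;

// Stand-in for the enum-based mask; BTreeSet<u64> replaces RowIdTreeMap.
enum RowIdMask {
    AllowList(BTreeSet<u64>),
    BlockList(BTreeSet<u64>),
}

impl RowIdMask {
    // An allow list can enumerate its row ids; a block list only says which
    // rows are *excluded*, so there is nothing concrete to iterate.
    fn iter_ids(&self) -> Option<impl Iterator<Item = u64> + '_> {
        match self {
            Self::AllowList(ids) => Some(ids.iter().copied()),
            Self::BlockList(_) => None,
        }
    }
}

fn main() {
    let mask = RowIdMask::AllowList(BTreeSet::from([1, 5, 9]));
    let ids: Vec<u64> = mask.iter_ids().expect("allow list").collect();
    assert_eq!(ids, vec![1, 5, 9]);

    // A block list cannot be materialized into concrete row ids.
    assert!(RowIdMask::BlockList(BTreeSet::new()).iter_ids().is_none());
}
```

Collapsing the mask to two variants is also what removes the fourth `(Some(..), Some(..))` match arm above: an allow list combined with a block list can no longer exist in un-normalized form.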