From e6ad464e736680e6e54f2e91e2878aad4d191319 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Thu, 31 Oct 2024 12:36:46 +0000 Subject: [PATCH] Speeding things up further by avoiding rank computations and bit access from scratch. Instead, we check whether the access still belongs within the same word that we cache locally. --- crates/quaternary_trie/src/lib.rs | 73 +++++++++++++------ crates/quaternary_trie/src/virtual_bitrank.rs | 61 ++++++++++++---- 2 files changed, 96 insertions(+), 38 deletions(-) diff --git a/crates/quaternary_trie/src/lib.rs b/crates/quaternary_trie/src/lib.rs index d8752f1..1981dda 100644 --- a/crates/quaternary_trie/src/lib.rs +++ b/crates/quaternary_trie/src/lib.rs @@ -1,4 +1,4 @@ -use virtual_bitrank::VirtualBitRank; +use virtual_bitrank::{VirtualBitRank, Word, WORD_BITS}; mod virtual_bitrank; @@ -178,7 +178,7 @@ impl QuarternaryTrie { }; let mut consumed = values; s.count_levels(&mut consumed, MAX_LEVEL - 1); - if matches!(layout, Layout::Linear) { + if true || matches!(layout, Layout::Linear) { s.data.reserve(s.level_idx.iter().sum::() * 4); } s.level_idx @@ -325,43 +325,65 @@ pub trait TrieIteratorTrait { pub struct TrieTraversal<'a> { trie: &'a QuarternaryTrie, + // The nibble position of the node for each level. pos: [u32; MAX_LEVEL], + // The remaining bits (nibbles) of the word covering the nibble position. + word: [Word; MAX_LEVEL], + // The 1-rank up to the nibble. This information is needed + // to determine the nibble/node position of the next level. + rank: [u32; MAX_LEVEL], } impl<'a> TrieTraversal<'a> { - pub fn new(bpt: &'a QuarternaryTrie) -> Self { + pub fn new(trie: &'a QuarternaryTrie) -> Self { + let word = trie.data.get_word_suffix(0); Self { - trie: bpt, + trie, pos: [0; MAX_LEVEL], + word: [word; MAX_LEVEL], + rank: [0; MAX_LEVEL], } } } impl TrieIteratorTrait for TrieTraversal<'_> { fn get(&self, level: usize) -> u32 { - self.trie.data.get_nibble(self.pos[level] as usize) + self.word[level] as u32 & 15 } fn down(&mut self, level: usize, child: u32) { - let index = self.pos[level] * 4 + child; - let new_index = self.trie.data.rank(index as usize + 1); - self.pos[level - 1] = new_index; + let new_pos = + self.rank[level] + (self.word[level] & !(Word::MAX << (child + 1))).count_ones(); + let old_pos = self.pos[level - 1]; + if (new_pos ^ old_pos) & !(WORD_BITS as u32 / 4 - 1) == 0 { + // In this case, we can reuse the old rank information + let delta = (new_pos - old_pos) * 4; + self.rank[level - 1] += (self.word[level - 1] & !(Word::MAX << delta)).count_ones(); + self.word[level - 1] = self.word[level - 1] >> delta; + } else { + if level > 1 { + // for level 0, we don't need the rank information + // self.rank[level - 1] = self.trie.data.rank(4 * new_pos as usize); + let (r, w) = self.trie.data.rank_with_word(4 * new_pos as usize); + self.rank[level - 1] = r; + self.word[level - 1] = w; + } else { + // TODO: Get word suffix and rank information in one go... + self.word[level - 1] = self.trie.data.get_word_suffix(4 * new_pos as usize); + } + } + self.pos[level - 1] = new_pos; } } pub struct TrieIterator { trie: T, item: u32, - nibbles: [u32; MAX_LEVEL], } impl TrieIterator { pub fn new(trie: T) -> Self { - Self { - trie, - item: 0, - nibbles: [0; MAX_LEVEL], - } + Self { trie, item: 0 } } } @@ -369,36 +391,39 @@ impl<'a, T: TrieIteratorTrait> Iterator for TrieIterator { type Item = u32; fn next(&mut self) -> Option { + let mut item = self.item; let mut level = if self.item == 0 { - self.nibbles[MAX_LEVEL - 1] = self.trie.get(MAX_LEVEL - 1); MAX_LEVEL - 1 } else { - (self.item.trailing_zeros() / 2) as usize + (item.trailing_zeros() / 2) as usize }; while level < MAX_LEVEL { - let child = (self.item >> (2 * level)) & 3; - let nibble = self.nibbles[level] >> child; + let child = (item >> (2 * level)) & 3; + let nibble = self.trie.get(level) >> child; if nibble != 0 { let delta = nibble.trailing_zeros(); if level == 0 { - let res = self.item + delta; + let res = item + delta; self.item = res + 1; return Some(res); } - self.item += delta << (2 * level); + item += delta << (2 * level); self.trie.down(level, child + delta); level -= 1; - self.nibbles[level] = self.trie.get(level); } else { - self.item |= 3 << (level * 2); - self.item += 1 << (level * 2); - level = (self.item.trailing_zeros() / 2) as usize; + item |= 3 << (level * 2); + item += 1 << (level * 2); + level = (item.trailing_zeros() / 2) as usize; } } + self.item = item; None } } +// TODO: Introduce a nibble summary structure which caches the computed merged nibble information. +// If the query tree becomes more complex, recomputing the merged nibble information becomes expensive. +// But for small query trees, it's not worth the effort to cache the information. pub struct Intersection { left: T, right: T, diff --git a/crates/quaternary_trie/src/virtual_bitrank.rs b/crates/quaternary_trie/src/virtual_bitrank.rs index 9a0852d..ccbd84b 100644 --- a/crates/quaternary_trie/src/virtual_bitrank.rs +++ b/crates/quaternary_trie/src/virtual_bitrank.rs @@ -38,13 +38,13 @@ use std::cell::RefCell; -type Word = u64; +pub(crate) type Word = u64; const BLOCK_BYTES: usize = 64; const BLOCK_BITS: usize = BLOCK_BYTES * 8; const BLOCKS_PER_PAGE: usize = BLOCK_BYTES / 4; -const WORD_BITS: usize = WORD_BYTES * 8; -const WORD_BYTES: usize = std::mem::size_of::(); +pub(crate) const WORD_BITS: usize = WORD_BYTES * 8; +pub(crate) const WORD_BYTES: usize = std::mem::size_of::(); const WORDS_PER_BLOCK: usize = BLOCK_BYTES / WORD_BYTES; const PAGE_BYTES: usize = BLOCKS_PER_PAGE * BLOCK_BYTES; const PAGE_BITS: usize = PAGE_BYTES * 8; @@ -94,14 +94,14 @@ impl VirtualBitRank { } fn bit_to_block(&self, bit: usize) -> usize { - //let block = bit / BLOCK_BITS; - //let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1; - let result = self.block_mapping[bit / BLOCK_BITS] as usize; + let block = bit / BLOCK_BITS; + let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1; + //let result = self.block_mapping[bit / BLOCK_BITS] as usize; //assert_eq!(result2, result); - if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) { - *v.borrow_mut() += 1 << (result % 64); - } - result + //if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) { + // *v.borrow_mut() += 1 << (result % 64); + //} + result2 } fn mid_rank(&self, block: usize) -> u32 { @@ -132,6 +132,33 @@ impl VirtualBitRank { } } + pub(crate) fn rank_with_word(&self, bit: usize) -> (u32, Word) { + let block = self.bit_to_block(bit); + let mut rank = self.mid_rank(block); + let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); + let bit_in_word = bit & (WORD_BITS - 1); + if word >= WORDS_PER_BLOCK / 2 { + for i in WORDS_PER_BLOCK / 2..word { + rank += self.blocks[block].words[i].count_ones(); + } + if bit_in_word != 0 { + ( + rank + (self.blocks[block].words[word] << (WORD_BITS - bit_in_word)) + .count_ones(), + self.blocks[block].words[word] >> bit_in_word, + ) + } else { + (rank, self.blocks[block].words[word]) + } + } else { + for i in word + 1..WORDS_PER_BLOCK / 2 { + rank -= self.blocks[block].words[i].count_ones(); + } + let w = self.blocks[block].words[word] >> bit_in_word; + (rank - w.count_ones(), w) + } + } + pub(crate) fn reserve(&mut self, bits: usize) { assert!(self.block_mapping.is_empty()); assert!(self.blocks.is_empty()); @@ -202,7 +229,13 @@ impl VirtualBitRank { } } - pub(crate) fn get(&self, bit: usize) -> bool { + pub(crate) fn get_word_suffix(&self, bit: usize) -> Word { + let block = self.bit_to_block(bit); + let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); + self.blocks[block].words[word] >> bit + } + + pub(crate) fn get_bit(&self, bit: usize) -> bool { let block = self.bit_to_block(bit); let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); let bit_in_word = bit & (WORD_BITS - 1); @@ -247,7 +280,7 @@ mod tests { } bitrank.build(); for (i, bit) in bits.iter().enumerate() { - assert_eq!(bitrank.get(i), *bit, "at position {i}"); + assert_eq!(bitrank.get_bit(i), *bit, "at position {i}"); } for (i, r) in rank.iter().enumerate() { assert_eq!(bitrank.rank(i), *r, "at position {i}"); @@ -282,7 +315,7 @@ mod tests { for _ in 0..4 { let time = Instant::now(); for i in &random_bits { - assert!(bitrank.get(*i as usize), "at position {i}"); + assert!(bitrank.get_bit(*i as usize), "at position {i}"); } println!( "time to check random bits: {:?} {:?}", @@ -292,7 +325,7 @@ mod tests { let time = Instant::now(); for i in &sorted_bits { - assert!(bitrank.get(*i as usize), "at position {i}"); + assert!(bitrank.get_bit(*i as usize), "at position {i}"); } println!( "time to check sorted bits: {:?} {:?}",