Skip to content

Commit

Permalink
Speeding things up further by avoiding rank computations and bit acce…
Browse files Browse the repository at this point in the history
…ss from scratch.

Instead, we check whether the access still belongs within the same word that we cache locally.
  • Loading branch information
aneubeck committed Oct 31, 2024
1 parent 2711809 commit e6ad464
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 38 deletions.
73 changes: 49 additions & 24 deletions crates/quaternary_trie/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use virtual_bitrank::VirtualBitRank;
use virtual_bitrank::{VirtualBitRank, Word, WORD_BITS};

mod virtual_bitrank;

Expand Down Expand Up @@ -178,7 +178,7 @@ impl QuarternaryTrie {
};
let mut consumed = values;
s.count_levels(&mut consumed, MAX_LEVEL - 1);
if matches!(layout, Layout::Linear) {
if true || matches!(layout, Layout::Linear) {
s.data.reserve(s.level_idx.iter().sum::<usize>() * 4);
}
s.level_idx
Expand Down Expand Up @@ -325,80 +325,105 @@ pub trait TrieIteratorTrait {

pub struct TrieTraversal<'a> {
trie: &'a QuarternaryTrie,
// The nibble position of the node for each level.
pos: [u32; MAX_LEVEL],
// The remaining bits (nibbles) of the word covering the nibble position.
word: [Word; MAX_LEVEL],
// The 1-rank up to the nibble. This information is needed
// to determine the nibble/node position of the next level.
rank: [u32; MAX_LEVEL],
}

impl<'a> TrieTraversal<'a> {
pub fn new(bpt: &'a QuarternaryTrie) -> Self {
pub fn new(trie: &'a QuarternaryTrie) -> Self {
let word = trie.data.get_word_suffix(0);
Self {
trie: bpt,
trie,
pos: [0; MAX_LEVEL],
word: [word; MAX_LEVEL],
rank: [0; MAX_LEVEL],
}
}
}

impl TrieIteratorTrait for TrieTraversal<'_> {
fn get(&self, level: usize) -> u32 {
self.trie.data.get_nibble(self.pos[level] as usize)
self.word[level] as u32 & 15
}

fn down(&mut self, level: usize, child: u32) {
let index = self.pos[level] * 4 + child;
let new_index = self.trie.data.rank(index as usize + 1);
self.pos[level - 1] = new_index;
let new_pos =
self.rank[level] + (self.word[level] & !(Word::MAX << (child + 1))).count_ones();
let old_pos = self.pos[level - 1];
if (new_pos ^ old_pos) & !(WORD_BITS as u32 / 4 - 1) == 0 {
// In this case, we can reuse the old rank information
let delta = (new_pos - old_pos) * 4;
self.rank[level - 1] += (self.word[level - 1] & !(Word::MAX << delta)).count_ones();
self.word[level - 1] = self.word[level - 1] >> delta;
} else {
if level > 1 {
// for level 0, we don't need the rank information
// self.rank[level - 1] = self.trie.data.rank(4 * new_pos as usize);
let (r, w) = self.trie.data.rank_with_word(4 * new_pos as usize);
self.rank[level - 1] = r;
self.word[level - 1] = w;
} else {
// TODO: Get word suffix and rank information in one go...
self.word[level - 1] = self.trie.data.get_word_suffix(4 * new_pos as usize);
}
}
self.pos[level - 1] = new_pos;
}
}

pub struct TrieIterator<T> {
trie: T,
item: u32,
nibbles: [u32; MAX_LEVEL],
}

impl<T: TrieIteratorTrait> TrieIterator<T> {
pub fn new(trie: T) -> Self {
Self {
trie,
item: 0,
nibbles: [0; MAX_LEVEL],
}
Self { trie, item: 0 }
}
}

impl<'a, T: TrieIteratorTrait> Iterator for TrieIterator<T> {
type Item = u32;

fn next(&mut self) -> Option<u32> {
let mut item = self.item;
let mut level = if self.item == 0 {
self.nibbles[MAX_LEVEL - 1] = self.trie.get(MAX_LEVEL - 1);
MAX_LEVEL - 1
} else {
(self.item.trailing_zeros() / 2) as usize
(item.trailing_zeros() / 2) as usize
};
while level < MAX_LEVEL {
let child = (self.item >> (2 * level)) & 3;
let nibble = self.nibbles[level] >> child;
let child = (item >> (2 * level)) & 3;
let nibble = self.trie.get(level) >> child;
if nibble != 0 {
let delta = nibble.trailing_zeros();
if level == 0 {
let res = self.item + delta;
let res = item + delta;
self.item = res + 1;
return Some(res);
}
self.item += delta << (2 * level);
item += delta << (2 * level);
self.trie.down(level, child + delta);
level -= 1;
self.nibbles[level] = self.trie.get(level);
} else {
self.item |= 3 << (level * 2);
self.item += 1 << (level * 2);
level = (self.item.trailing_zeros() / 2) as usize;
item |= 3 << (level * 2);
item += 1 << (level * 2);
level = (item.trailing_zeros() / 2) as usize;
}
}
self.item = item;
None
}
}

// TODO: Introduce a nibble summary structure which caches the computed merged nibble information.
// If the query tree becomes more complex, recomputing the merged nibble information becomes expensive.
// But for small query trees, it's not worth the effort to cache the information.
pub struct Intersection<T> {
left: T,
right: T,
Expand Down
61 changes: 47 additions & 14 deletions crates/quaternary_trie/src/virtual_bitrank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@

use std::cell::RefCell;

type Word = u64;
pub(crate) type Word = u64;

const BLOCK_BYTES: usize = 64;
const BLOCK_BITS: usize = BLOCK_BYTES * 8;
const BLOCKS_PER_PAGE: usize = BLOCK_BYTES / 4;
const WORD_BITS: usize = WORD_BYTES * 8;
const WORD_BYTES: usize = std::mem::size_of::<Word>();
pub(crate) const WORD_BITS: usize = WORD_BYTES * 8;
pub(crate) const WORD_BYTES: usize = std::mem::size_of::<Word>();
const WORDS_PER_BLOCK: usize = BLOCK_BYTES / WORD_BYTES;
const PAGE_BYTES: usize = BLOCKS_PER_PAGE * BLOCK_BYTES;
const PAGE_BITS: usize = PAGE_BYTES * 8;
Expand Down Expand Up @@ -94,14 +94,14 @@ impl VirtualBitRank {
}

fn bit_to_block(&self, bit: usize) -> usize {
//let block = bit / BLOCK_BITS;
//let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1;
let result = self.block_mapping[bit / BLOCK_BITS] as usize;
let block = bit / BLOCK_BITS;
let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1;
//let result = self.block_mapping[bit / BLOCK_BITS] as usize;
//assert_eq!(result2, result);
if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) {
*v.borrow_mut() += 1 << (result % 64);
}
result
//if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) {
// *v.borrow_mut() += 1 << (result % 64);
//}
result2
}

fn mid_rank(&self, block: usize) -> u32 {
Expand Down Expand Up @@ -132,6 +132,33 @@ impl VirtualBitRank {
}
}

pub(crate) fn rank_with_word(&self, bit: usize) -> (u32, Word) {
let block = self.bit_to_block(bit);
let mut rank = self.mid_rank(block);
let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1);
let bit_in_word = bit & (WORD_BITS - 1);
if word >= WORDS_PER_BLOCK / 2 {
for i in WORDS_PER_BLOCK / 2..word {
rank += self.blocks[block].words[i].count_ones();
}
if bit_in_word != 0 {
(
rank + (self.blocks[block].words[word] << (WORD_BITS - bit_in_word))
.count_ones(),
self.blocks[block].words[word] >> bit_in_word,
)
} else {
(rank, self.blocks[block].words[word])
}
} else {
for i in word + 1..WORDS_PER_BLOCK / 2 {
rank -= self.blocks[block].words[i].count_ones();
}
let w = self.blocks[block].words[word] >> bit_in_word;
(rank - w.count_ones(), w)
}
}

pub(crate) fn reserve(&mut self, bits: usize) {
assert!(self.block_mapping.is_empty());
assert!(self.blocks.is_empty());
Expand Down Expand Up @@ -202,7 +229,13 @@ impl VirtualBitRank {
}
}

pub(crate) fn get(&self, bit: usize) -> bool {
pub(crate) fn get_word_suffix(&self, bit: usize) -> Word {
let block = self.bit_to_block(bit);
let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1);
self.blocks[block].words[word] >> bit
}

pub(crate) fn get_bit(&self, bit: usize) -> bool {
let block = self.bit_to_block(bit);
let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1);
let bit_in_word = bit & (WORD_BITS - 1);
Expand Down Expand Up @@ -247,7 +280,7 @@ mod tests {
}
bitrank.build();
for (i, bit) in bits.iter().enumerate() {
assert_eq!(bitrank.get(i), *bit, "at position {i}");
assert_eq!(bitrank.get_bit(i), *bit, "at position {i}");
}
for (i, r) in rank.iter().enumerate() {
assert_eq!(bitrank.rank(i), *r, "at position {i}");
Expand Down Expand Up @@ -282,7 +315,7 @@ mod tests {
for _ in 0..4 {
let time = Instant::now();
for i in &random_bits {
assert!(bitrank.get(*i as usize), "at position {i}");
assert!(bitrank.get_bit(*i as usize), "at position {i}");
}
println!(
"time to check random bits: {:?} {:?}",
Expand All @@ -292,7 +325,7 @@ mod tests {

let time = Instant::now();
for i in &sorted_bits {
assert!(bitrank.get(*i as usize), "at position {i}");
assert!(bitrank.get_bit(*i as usize), "at position {i}");
}
println!(
"time to check sorted bits: {:?} {:?}",
Expand Down

0 comments on commit e6ad464

Please sign in to comment.