Skip to content

Commit

Permalink
Add union implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aneubeck committed Oct 31, 2024
1 parent e6ad464 commit 6a9a1f0
Showing 1 changed file with 133 additions and 1 deletion.
134 changes: 133 additions & 1 deletion crates/quaternary_trie/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,112 @@ impl<T: TrieIteratorTrait> TrieIteratorTrait for Intersection<T> {
}
}

/// Three-state marker: either no split, or a split tagged with a side and a `usize` payload.
///
/// NOTE(review): nothing in the active code visible in this part of the file constructs or
/// matches on this enum; `Union` keeps a `split: usize` plus a `swap: bool` instead.
/// Presumably the payload is the trie level at which the two sides diverged — confirm, or
/// remove the enum if the rest of the file does not use it.
enum Split {
/// No split is recorded.
None,
/// Split associated with the left side; carries a `usize` (presumably a level — TODO confirm).
Left(usize),
/// Split associated with the right side; carries a `usize` (presumably a level — TODO confirm).
Right(usize),
}

/// Traversal state for the union of two trie traversals of the same type `T`.
///
/// Both inputs are walked together for as long as they share the requested child. Once only
/// one of them has the requested child, only that input is descended and the level at which
/// this happened is remembered in `split`.
#[derive(Debug)]
pub struct Union<T> {
    /// The two inputs: index 0 is the "left" one, index 1 the "right" one.
    inner: [T; 2],
    /// Levels strictly below this value are served by a single input only (the one selected
    /// by `swap`); levels at or above it combine both inputs. `0` means "no split".
    split: usize,
    /// Selects the input used below `split`: `false` -> `inner[0]`, `true` -> `inner[1]`.
    swap: bool,
}

impl<T: TrieIteratorTrait> Union<T> {
pub fn new(left: T, right: T) -> Self {
Self {
inner: [left, right],
swap: false,
split: 0,
}
}
}

impl<T: TrieIteratorTrait> TrieIteratorTrait for Union<T> {
    /// Returns the child bitmask of the union at `level`.
    ///
    /// For `level >= self.split` the masks of both inputs are OR-ed together. For lower
    /// levels only the input selected by `swap` is consulted, because `down` descended only
    /// that input below the split.
    ///
    /// NOTE(review): this relies on levels decreasing while descending — confirm against the
    /// `TrieIteratorTrait` contract.
    fn get(&self, level: usize) -> u32 {
        if level >= self.split {
            return self.inner[0].get(level) | self.inner[1].get(level);
        }
        self.inner[usize::from(self.swap)].get(level)
    }

    /// Descends into `child` at `level`, moving only the inputs that contain that child.
    ///
    /// Below the recorded split only the active input is moved. Otherwise both inputs are
    /// probed for `child`: if both have it, both descend and the split is cleared; if only
    /// one has it, that input descends alone and `level` is recorded as the new split.
    ///
    /// # Panics
    ///
    /// Panics if `level >= self.split` and neither input has `child` set at `level`.
    fn down(&mut self, level: usize, child: u32) {
        if level < self.split {
            // Still inside a single-sided subtree: the other input was never moved here.
            let active = usize::from(self.swap);
            self.inner[active].down(level, child);
            return;
        }

        let contains_child = |side: &T| side.get(level) & (1 << child) != 0;
        let has_left = contains_child(&self.inner[0]);
        let has_right = contains_child(&self.inner[1]);

        match (has_left, has_right) {
            (true, true) => {
                self.split = 0;
                self.inner[0].down(level, child);
                self.inner[1].down(level, child);
            }
            (true, false) => {
                self.split = level;
                self.swap = false;
                self.inner[0].down(level, child);
            }
            (false, _) => {
                // The caller must only request children reported by `get`.
                assert!(has_right);
                self.split = level;
                self.swap = true;
                self.inner[1].down(level, child);
            }
        }
    }
}

#[cfg(test)]
mod tests {
use std::time::Instant;

use itertools::{kmerge, Itertools};
use rand::{thread_rng, Rng};

use crate::{Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal};
use crate::{Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal, Union};

#[test]
fn test_trie() {
Expand Down Expand Up @@ -497,6 +595,40 @@ mod tests {
assert_eq!(trie.collect(), values);
}

/// Checks that iterating a `Union` of two tries yields exactly the sorted, de-duplicated
/// merge of the two input sets, and prints the per-element time of both approaches.
#[test]
fn test_union() {
    // Both inputs are drawn from the same thread-local generator; they differ on every run.
    let mut rng = thread_rng();

    let mut values: Vec<_> = (0..1000000)
        .map(|_| rng.gen_range(0..100000000))
        .collect();
    values.sort();
    values.dedup();

    let mut values2: Vec<_> = (0..10000000)
        .map(|_| rng.gen_range(0..100000000))
        .collect();
    values2.sort();
    values2.dedup();

    // Reference result: k-way merge of the two sorted inputs with duplicates removed.
    let reference_timer = Instant::now();
    let expected: Vec<_> = kmerge([values.iter(), values2.iter()])
        .copied()
        .dedup()
        .collect();
    let expected_len = expected.len();
    println!(
        "kmerge union {:?}",
        reference_timer.elapsed() / expected_len as u32
    );
    println!("Union size: {}", expected_len);

    // Trie construction and iterator setup are deliberately kept outside the timed section.
    let trie = QuarternaryTrie::new(&values, Layout::Linear);
    let trie2 = QuarternaryTrie::new(&values2, Layout::Linear);
    let left = TrieTraversal::new(&trie);
    let right = TrieTraversal::new(&trie2);
    let iter = TrieIterator::new(Union::new(left, right));

    let trie_timer = Instant::now();
    let result: Vec<_> = iter.collect();
    println!(
        "trie union {:?}",
        trie_timer.elapsed() / expected_len as u32
    );
    assert_eq!(result, expected);
}

#[test]
fn test_intersection() {
let mut page_counts = [0, 0, 0];
Expand Down

0 comments on commit 6a9a1f0

Please sign in to comment.