diff --git a/.gitignore b/.gitignore
index 3f0956c..0654041 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,4 @@ Cargo.lock
 /target/
 /crates/*/target/
 /crates/*/Cargo.lock
-.vscode/
\ No newline at end of file
+.vscode/
diff --git a/crates/bpe/benches/counting.rs b/crates/bpe/benches/counting.rs
index 61b718b..9b746d3 100644
--- a/crates/bpe/benches/counting.rs
+++ b/crates/bpe/benches/counting.rs
@@ -6,85 +6,100 @@ use criterion::{criterion_group, criterion_main, Criterion};
 use rand::{thread_rng, Rng};
 
 fn counting_benchmark(c: &mut Criterion) {
-    let bpe = BytePairEncoding::cl100k();
-    let text = create_test_bytes(&bpe, 20000);
+    for (name, bpe) in [
+        ("cl100k", BytePairEncoding::cl100k()),
+        ("o200k", BytePairEncoding::o200k()),
+    ] {
+        let text = create_test_bytes(&bpe, 20000);
+        let fast = IntervalEncoding::new(&bpe, &text);
 
-    let fast = IntervalEncoding::new(&bpe, &text);
-
-    for bytes in [10, 100, 1000, 10000] {
-        let mut group = c.benchmark_group(format!("bytes-{bytes}"));
-        group.bench_function("hybrid counting", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..text.len() - bytes),
-                |start| fast.count(start..start + bytes),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("backtrack counting", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..text.len() - bytes),
-                |start| bpe.count(&text[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
+        for bytes in [10, 100, 1000, 10000] {
+            let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
+            group.bench_function("hybrid counting", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..text.len() - bytes),
+                    |start| fast.count(start..start + bytes),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("backtrack counting", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..text.len() - bytes),
+                    |start| bpe.count(&text[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+        }
     }
 }
 
 fn encoding_benchmark(c: &mut Criterion) {
-    let bpe = BytePairEncoding::cl100k();
-    let tiktoken = tiktoken_rs::cl100k_base().unwrap();
-    let text = create_test_string(&bpe, 20000);
-    let input = text.as_bytes();
+    for (name, bpe, tiktoken) in [
+        (
+            "cl100k",
+            BytePairEncoding::cl100k(),
+            tiktoken_rs::cl100k_base().unwrap(),
+        ),
+        (
+            "o200k",
+            BytePairEncoding::o200k(),
+            tiktoken_rs::o200k_base().unwrap(),
+        ),
+    ] {
+        let text = create_test_string(&bpe, 20000);
+        let input = text.as_bytes();
 
-    for bytes in [10, 100, 1000, 10000] {
-        let mut group = c.benchmark_group(format!("bytes-{bytes}"));
-        group.bench_function("backtracking", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..input.len() - bytes),
-                |start| bpe.encode_via_backtracking(&input[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("heap", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..input.len() - bytes),
-                |start| bpe.encode_via_bitfield(&input[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("dynamic programming", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..input.len() - bytes),
-                |start| bpe.encode_via_table(&input[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("greedy", |b| {
-            b.iter_batched(
-                || thread_rng().gen_range(0..input.len() - bytes),
-                |start| bpe.encode_greedy(&input[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("minimal", |b| {
-            b.iter_batched(
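+        // run each encoder (and tiktoken as a baseline) on random slices of every length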
-                || thread_rng().gen_range(0..input.len() - bytes),
-                |start| bpe.encode_minimal(&input[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
-        group.bench_function("tiktoken", |b| {
-            b.iter_batched(
-                || loop {
-                    let start = thread_rng().gen_range(0..input.len() - bytes - 1);
-                    if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) {
-                        return start;
-                    }
-                },
-                |start| tiktoken.encode_ordinary(&text[start..start + bytes]),
-                criterion::BatchSize::SmallInput,
-            )
-        });
+        for bytes in [10, 100, 1000, 10000] {
+            let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
+            group.bench_function("backtracking", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..input.len() - bytes),
+                    |start| bpe.encode_via_backtracking(&input[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("heap", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..input.len() - bytes),
+                    |start| bpe.encode_via_bitfield(&input[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("dynamic programming", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..input.len() - bytes),
+                    |start| bpe.encode_via_table(&input[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("greedy", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..input.len() - bytes),
+                    |start| bpe.encode_greedy(&input[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("minimal", |b| {
+                b.iter_batched(
+                    || thread_rng().gen_range(0..input.len() - bytes),
+                    |start| bpe.encode_minimal(&input[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+            group.bench_function("tiktoken", |b| {
+                b.iter_batched(
+                    || loop {
+                        let start = thread_rng().gen_range(0..input.len() - bytes - 1);
+                        if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes])
+                        {
+                            return start;
+                        }
+                    },
+                    |start| tiktoken.encode_ordinary(&text[start..start + bytes]),
+                    criterion::BatchSize::SmallInput,
+                )
+            });
+        }
     }
 }
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index c1df5b6..873d7be 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -19,6 +19,11 @@ static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
     rmp_serde::from_slice(bytes).expect("")
 });
 
+static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
+    let bytes = include_bytes!("data/bpe_o200k.dict");
+    rmp_serde::from_slice(bytes).expect("")
+});
+
 /// Representation of the byte pair dictionary.
 /// This struct provides various conversions.
 /// We put all of them into a single struct so that they can be reused by different implementations.
@@ -153,11 +158,15 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) -> &'a [u8] {
 }
 
 fn hash_bytes(bytes: &[u8]) -> u32 {
+    hash_bytes_with_factor(bytes, 17846336922010275747)
+}
+
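+/// Like [`hash_bytes`], but with an explicit multiplication factor, so that tests
+/// can search for a factor that is collision-free for all supported token sets.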
+fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
     let mut hasher = FnvHasher::default();
     bytes.hash(&mut hasher);
     // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
     // To make them unique for the given tokens, we have to add unfortunately another multiplication.
-    ((hasher.finish().wrapping_mul(37493864257)) >> 32) as u32
+    ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }
 
 fn find_token_by_bytes(
@@ -180,6 +189,10 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    pub fn o200k() -> &'static Self {
+        &BPE_O200K
+    }
+
     /// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
         Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
@@ -492,13 +505,11 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
 
 #[cfg(test)]
 mod tests {
-    use std::fs::File;
-    use std::path::PathBuf;
+    use std::time::Instant;
 
     use itertools::Itertools;
-    use serde::Serialize;
-    use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
+    use tiktoken_rs::cl100k_base_singleton;
 
     use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
 
@@ -541,19 +552,72 @@ mod tests {
             }
         }
     }
+}
 
-    // TODO: Move the generation of the dictionary into some build procedure?
+#[cfg(test)]
+mod data {
+    use std::collections::HashSet;
+    use std::fs::File;
+    use std::path::PathBuf;
+
+    use rand::Rng;
+    use serde::Serialize;
+    use tiktoken_rs::{cl100k_base, o200k_base};
+
+    use super::*;
+
+    const BPE_CL100K_LEN: usize = 100256;
+    const BPE_O200K_LEN: usize = 199998;
+
+    /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
+    /// 1. Ensure all supported tokenizers are in the list.
+    /// 2. Update the hash factor in [`hash_bytes`].
+    /// 3. Run [`update_token_dicts`] tests below to update data files.
     #[test]
-    fn test_serialize() {
+    #[ignore = "run manually to find a suitable hash factor"]
+    fn find_hash_factor() {
+        let bpes: &mut [(CoreBPE, usize)] = &mut [
+            (cl100k_base().unwrap(), BPE_CL100K_LEN),
+            (o200k_base().unwrap(), BPE_O200K_LEN),
+        ];
+        let mut rnd = rand::thread_rng();
+        loop {
+            let factor: u64 = rnd.gen();
+            if bpes.iter().all(|(bpe, len)| {
+                let mut seen = HashSet::with_capacity(*len);
+                (0..*len)
+                    .all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
+            }) {
+                println!("hash factor: {factor}");
+                return;
+            }
+        }
+    }
+
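+    /// Regenerates the serialized dictionaries under `src/data/` for all supported tokenizers.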
+    #[test]
+    fn update_token_dicts() {
+        serialize_tokens(
+            &cl100k_base().expect("tiktoken initialization must not fail!"),
+            BPE_CL100K_LEN,
+            "cl100k",
+        );
+        serialize_tokens(
+            &o200k_base().expect("tiktoken initialization must not fail!"),
+            BPE_O200K_LEN,
+            "o200k",
+        );
+    }
+
+    #[track_caller]
+    fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
         let path = PathBuf::from(file!());
         let dir = path.parent().unwrap();
-        let data_file = dir.join("data/bpe_cl100k.dict");
+        let data_file = dir.join(format!("data/bpe_{name}.dict"));
         let current_dir = std::env::current_dir().unwrap();
         let abs_path = current_dir.parent().unwrap().parent().unwrap();
         let file = File::create(abs_path.join(data_file)).unwrap();
         let mut serializer = rmp_serde::Serializer::new(file);
-        let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
-        BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
+        BytePairEncoding::from_tiktoken(dict, num_tokens)
             .serialize(&mut serializer)
             .unwrap();
     }
diff --git a/crates/bpe/src/data/bpe_cl100k.dict b/crates/bpe/src/data/bpe_cl100k.dict
index d8bc6e5..c28cb32 100644
Binary files a/crates/bpe/src/data/bpe_cl100k.dict and b/crates/bpe/src/data/bpe_cl100k.dict differ
diff --git a/crates/bpe/src/data/bpe_o200k.dict b/crates/bpe/src/data/bpe_o200k.dict
new file mode 100644
index 0000000..88719e3
Binary files /dev/null and b/crates/bpe/src/data/bpe_o200k.dict differ
diff --git a/crates/geo_filters/src/config/bitchunks.rs b/crates/geo_filters/src/config/bitchunks.rs
index 937bcba..72660c7 100644
--- a/crates/geo_filters/src/config/bitchunks.rs
+++ b/crates/geo_filters/src/config/bitchunks.rs
@@ -217,9 +217,8 @@ pub(crate) fn count_ones_from_bitchunks(
     let mut total = take_ref(&mut ones, max_msb_len - 1).count();
     let smallest_msb = ones
         .next()
-        .map(|bucket| {
+        .inspect(|_| {
             total += 1;
-            bucket
         })
         .unwrap_or_default();
 
diff --git a/criterion.toml b/criterion.toml
new file mode 100644
index 0000000..c0f42f2
--- /dev/null
+++ b/criterion.toml
@@ -0,0 +1,2 @@
+# save report in this directory, even if a custom target directory is set
+criterion_home = "./target/criterion"