diff --git a/crates/bpe/README.md b/crates/bpe/README.md
index 0dcb703..e90ceef 100644
--- a/crates/bpe/README.md
+++ b/crates/bpe/README.md
@@ -294,9 +294,8 @@ This suggests that pre-tokenization is not necessary from a performance perspective.
 
 ![encoding runtime comparison](./images/performance-comparison.svg)
 
-The graph below shows encoding results for input that is particularly challenging for tiktoken.
-The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace.
-The performance of tiktoken shows a quadratic growth with the input size.
+The graph below shows encoding results for input that cannot be split in pre-tokenization, which allows a better comparison of pure BPE performance.
+This case is particularly challenging for tiktoken, which shows quadratic growth with the input size.
 The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.
 
 ![worst-case encoding runtime comparison](./images/performance-worstcase.svg)
diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs
index b3b4d59..d192225 100644
--- a/crates/bpe/benchmarks/performance.rs
+++ b/crates/bpe/benchmarks/performance.rs
@@ -1,7 +1,9 @@
 use std::time::Duration;
 
 use bpe::appendable_encoder::AppendableEncoder;
-use bpe::byte_pair_encoding::{create_test_string, select_test_string};
+use bpe::byte_pair_encoding::{
+    create_test_string, create_test_string_with_predicate, select_test_string,
+};
 use bpe::interval_encoding::IntervalEncoding;
 use bpe_benchmarks::*;
 use criterion::{
@@ -11,7 +13,7 @@ use rand::{thread_rng, Rng};
 
 fn counting_benchmark(c: &mut Criterion) {
     for (name, bpe, _, _) in TOKENIZERS.iter() {
-        let input = create_test_string(&bpe.bpe, 80000);
+        let input = create_test_string(&bpe.bpe, 80_000);
         let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());
 
         let mut group = c.benchmark_group(format!("counting-{name}"));
@@ -185,11 +187,13 @@ fn comparison_benchmark(c: &mut Criterion) {
 }
 
 fn worstcase_comparison_benchmark(c: &mut Criterion) {
-    for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
-        let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect();
+    for (name, tok, tiktoken, huggingface) in TOKENIZERS.iter() {
+        let text = create_test_string_with_predicate(&tok.bpe, 100000, |text| {
+            tok.split(text).nth(1).is_none()
+        });
 
         let mut group = c.benchmark_group(format!("worstcase-{name}"));
-        for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
+        for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000] {
             group.throughput(criterion::Throughput::Bytes(bytes as u64));
             group.bench_with_input(
                 BenchmarkId::new("backtracking", bytes),
@@ -197,7 +201,7 @@
                 |b, bytes| {
                     b.iter_batched(
                         || select_test_string(&text, *bytes),
-                        |text| bpe.encode(text),
+                        |text| tok.encode(text),
                         criterion::BatchSize::SmallInput,
                     )
                 },
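The predicate in `worstcase_comparison_benchmark` above accepts a candidate string exactly when pre-tokenization yields at most one segment, so the generated input exercises pure BPE performance. A minimal sketch of that check, with a hypothetical whitespace splitter standing in for the tokenizer's own `split`:

```rust
/// Sketch: accept `text` only if the (stand-in) pre-tokenizer cannot split it.
/// Mirrors the `tok.split(text).nth(1).is_none()` predicate in the benchmark.
fn is_single_segment(text: &str) -> bool {
    // `nth(1)` consumes the first segment and asks for a second one;
    // `None` means the splitter produced at most one segment.
    text.split_whitespace().nth(1).is_none()
}

fn main() {
    assert!(is_single_segment("unsplittable"));
    assert!(!is_single_segment("two segments"));
}
```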
diff --git a/crates/bpe/images/performance-appending.svg b/crates/bpe/images/performance-appending.svg
index f0d1b69..486fc5e 100644
--- a/crates/bpe/images/performance-appending.svg
+++ b/crates/bpe/images/performance-appending.svg
@@ -1,52 +1,64 @@
 [regenerated SVG chart; markup omitted]
diff --git a/crates/bpe/images/performance-comparison.svg b/crates/bpe/images/performance-comparison.svg
index a6c89f7..aa2ea36 100644
--- a/crates/bpe/images/performance-comparison.svg
+++ b/crates/bpe/images/performance-comparison.svg
@@ -1,58 +1,76 @@
 [regenerated SVG chart; markup omitted]
diff --git a/crates/bpe/images/performance-counting.svg b/crates/bpe/images/performance-counting.svg
index 2dff836..20b3e9d 100644
--- a/crates/bpe/images/performance-counting.svg
+++ b/crates/bpe/images/performance-counting.svg
@@ -1,48 +1,64 @@
 [regenerated SVG chart; markup omitted]
diff --git a/crates/bpe/images/performance-encoding.svg b/crates/bpe/images/performance-encoding.svg
index a45eec9..46f7d3c 100644
--- a/crates/bpe/images/performance-encoding.svg
+++ b/crates/bpe/images/performance-encoding.svg
@@ -1,76 +1,88 @@
 [regenerated SVG chart; markup omitted]
diff --git a/crates/bpe/images/performance-worstcase.svg b/crates/bpe/images/performance-worstcase.svg
index 132b8b3..21b3c61 100644
--- a/crates/bpe/images/performance-worstcase.svg
+++ b/crates/bpe/images/performance-worstcase.svg
@@ -1,91 +1,91 @@
 [regenerated SVG chart; markup omitted]
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index 695e549..a90b320 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -553,18 +553,49 @@ impl BytePairEncoding {
     }
 }
 
-/// Generate a test string by concatenating random tokens.
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
 #[cfg(feature = "rand")]
 pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
+    create_test_string_with_predicate(bpe, min_bytes, |_| true)
+}
+
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
+/// The given predicate enforces other properties on the generated string. Note that this can hurt performance or
+/// even cause non-termination!
+#[cfg(feature = "rand")]
+pub fn create_test_string_with_predicate(
+    bpe: &BytePairEncoding,
+    min_bytes: usize,
+    predicate: impl Fn(&str) -> bool,
+) -> String {
     use rand::{thread_rng, Rng};
+    // the string we accumulated thus far
     let mut result = String::new();
-    while result.len() < min_bytes {
-        let i = thread_rng().gen_range(0..bpe.num_tokens());
-        // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
-        // token set. The chance of constructing a valid UTF-8 character across a token boundary
-        // by picking random tokens is so small that it is unlikely to happen anyway.
-        if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
-            result.push_str(token);
+    // the tokens we added so we can backtrack
+    let mut tokens = Vec::new();
+    'keep: while result.len() < min_bytes {
+        // try a few times to find a suitable token
+        'next: for _ in 0..8 {
+            // pick a random token and provisionally add it
+            let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
+            // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
+            // token set. The chance of constructing a valid UTF-8 character across a token boundary
+            // by picking random tokens is so small that it is unlikely to happen anyway.
+            if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
+                result.push_str(token);
+            } else {
+                continue 'next;
+            }
+            if predicate(&result) {
+                tokens.push(i);
+                continue 'keep;
+            } else {
+                result.truncate(result.len() - bpe.token_len(i));
+            }
+        }
+        // we didn't find anything after a few tries, backtrack
+        if let Some(i) = tokens.pop() {
+            result.truncate(result.len() - bpe.token_len(i));
+        }
     }
     result
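For illustration, a sketch of how the new generator could be called directly; it assumes a `BytePairEncoding` value and a build with the `rand` feature, and the ASCII-only predicate is an arbitrary example, not part of the change:

```rust
use bpe::byte_pair_encoding::{create_test_string_with_predicate, BytePairEncoding};

// Sketch: generate at least `min_bytes` of random-token text that stays ASCII.
// The predicate is re-checked on every accumulated prefix, so it should be
// cheap, and it must remain satisfiable or generation may not terminate
// (as the new doc comment warns).
fn ascii_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
    create_test_string_with_predicate(bpe, min_bytes, |text| text.is_ascii())
}
```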
diff --git a/crates/geo_filters/evaluation/accuracy.rs b/crates/geo_filters/evaluation/accuracy.rs
index 2e2eed0..c3151b7 100644
--- a/crates/geo_filters/evaluation/accuracy.rs
+++ b/crates/geo_filters/evaluation/accuracy.rs
@@ -105,7 +105,8 @@ impl Accuracy {
             .config
             .iter()
             .map(|c| {
-                simulation_config_from_str(c).expect(&format!("not a valid configuration: {}", c))
+                simulation_config_from_str(c)
+                    .unwrap_or_else(|_| panic!("not a valid configuration: {}", c))
             })
             .collect_vec();
         let set_sizes = if self.set_size.is_empty() {
@@ -118,9 +119,10 @@
 
         let mut output = self.output;
         output.set_extension("csv");
-        let f = File::create(&output).expect(&format!("cannot create file: {}", output.display()));
+        let f = File::create(&output)
+            .unwrap_or_else(|_| panic!("cannot create file: {}", output.display()));
         write_simulation_results(&configs, &set_sizes, results, f)
-            .expect(&format!("cannot write file: {}", output.display()));
+            .unwrap_or_else(|_| panic!("cannot write file: {}", output.display()));
         println!("  csv file = {}", output.display());
         println!();
     }
@@ -139,9 +141,9 @@ impl SimulationConfigParser {
         Self(Regex::new(re).expect(""), Arc::new(f))
     }
 
-    fn parse<'a>(&self, name: &str) -> Option<SimulationConfig> {
+    fn parse(&self, name: &str) -> Option<SimulationConfig> {
         self.0
-            .captures(&name)
+            .captures(name)
            .map(self.1.as_ref())
            .map(|p| (name.to_string(), p))
     }
@@ -225,7 +227,11 @@ fn simulation_config_from_str(name: &str) -> Result<SimulationConfig, String> {
 fn capture_usizes<const N: usize>(c: &Captures, is: [usize; N]) -> [usize; N] {
     let mut values = [0; N];
     for i in 0..is.len() {
-        values[i] = usize::from_str_radix(c.get(is[i]).expect("capture to exist").as_str(), 10)
+        values[i] = c
+            .get(is[i])
+            .expect("capture to exist")
+            .as_str()
+            .parse::<usize>()
             .expect("number string");
     }
     values
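The `accuracy.rs` hunks are mechanical lint fixes: `expect(&format!(...))` becomes `unwrap_or_else(|_| panic!(...))` (clippy's `expect_fun_call`), `usize::from_str_radix(.., 10)` becomes `str::parse` (`from_str_radix_10`), and an unused lifetime plus a needless borrow are removed. A small self-contained sketch of the first two rewrites; the file path is hypothetical:

```rust
use std::fs::File;

fn main() {
    let path = "out.csv"; // hypothetical path for illustration
    // `expect(&format!(..))` builds the message String even on success;
    // `unwrap_or_else(|_| panic!(..))` formats it only on the error path.
    let _file = File::create(path)
        .unwrap_or_else(|_| panic!("cannot create file: {}", path));

    // Base-10 `from_str_radix` is just ordinary parsing; `parse` says so directly.
    let n = "42".parse::<usize>().expect("number string");
    assert_eq!(n, 42);
}
```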