diff --git a/crates/bpe/README.md b/crates/bpe/README.md
index 0dcb703..e90ceef 100644
--- a/crates/bpe/README.md
+++ b/crates/bpe/README.md
@@ -294,9 +294,8 @@ This suggests that pre-tokenization is not necessary from a performance perspect

-The graph below shows encoding results for input that is particularly challenging for tiktoken.
-The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace.
-The performance of tiktoken shows a quadratic growth with the input size.
+The graph below shows encoding results for input that cannot be split by pre-tokenization, which allows a better comparison of the pure BPE performance.
+This case is particularly challenging for tiktoken, whose runtime grows quadratically with the input size.
The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.
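
For context, the unsplittable input for this benchmark is generated with the `create_test_string_with_predicate` helper introduced in this change. A minimal sketch of the idea (assuming `tok` bundles a `bpe` encoder with a `split` pre-tokenizer, as in the benchmark code below):

```rust
// Grow a random test string, accepting a token only while the pre-tokenizer
// still keeps the whole string as a single piece (i.e. it cannot be split).
let text = create_test_string_with_predicate(&tok.bpe, 100_000, |text| {
    tok.split(text).nth(1).is_none()
});
```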

diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs
index b3b4d59..d192225 100644
--- a/crates/bpe/benchmarks/performance.rs
+++ b/crates/bpe/benchmarks/performance.rs
@@ -1,7 +1,9 @@
use std::time::Duration;
use bpe::appendable_encoder::AppendableEncoder;
-use bpe::byte_pair_encoding::{create_test_string, select_test_string};
+use bpe::byte_pair_encoding::{
+ create_test_string, create_test_string_with_predicate, select_test_string,
+};
use bpe::interval_encoding::IntervalEncoding;
use bpe_benchmarks::*;
use criterion::{
@@ -11,7 +13,7 @@ use rand::{thread_rng, Rng};
fn counting_benchmark(c: &mut Criterion) {
for (name, bpe, _, _) in TOKENIZERS.iter() {
- let input = create_test_string(&bpe.bpe, 80000);
+ let input = create_test_string(&bpe.bpe, 80_000);
let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());
let mut group = c.benchmark_group(format!("counting-{name}"));
@@ -185,11 +187,13 @@ fn comparison_benchmark(c: &mut Criterion) {
}
fn worstcase_comparison_benchmark(c: &mut Criterion) {
- for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
- let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect();
+ for (name, tok, tiktoken, huggingface) in TOKENIZERS.iter() {
+ let text = create_test_string_with_predicate(&tok.bpe, 100_000, |text| {
+ tok.split(text).nth(1).is_none()
+ });
let mut group = c.benchmark_group(format!("worstcase-{name}"));
- for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
+ for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000] {
group.throughput(criterion::Throughput::Bytes(bytes as u64));
group.bench_with_input(
BenchmarkId::new("backtracking", bytes),
@@ -197,7 +201,7 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) {
|b, bytes| {
b.iter_batched(
|| select_test_string(&text, *bytes),
- |text| bpe.encode(text),
+ |text| tok.encode(text),
criterion::BatchSize::SmallInput,
)
},
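
The predicate above relies on a small iterator idiom: `nth(1).is_none()` holds exactly when an iterator yields at most one item. A self-contained sketch (the `is_single_piece` helper is hypothetical, for illustration only):

```rust
// `nth(1)` consumes up to two items and returns the second one, so
// `nth(1).is_none()` means the iterator produced fewer than two items.
fn is_single_piece<I: Iterator>(mut pieces: I) -> bool {
    pieces.nth(1).is_none()
}

fn main() {
    assert!(is_single_piece(std::iter::once("abc"))); // one piece: accepted
    assert!(!is_single_piece(["ab", "c"].into_iter())); // two pieces: rejected
}
```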
diff --git a/crates/bpe/images/performance-appending.svg b/crates/bpe/images/performance-appending.svg
index f0d1b69..486fc5e 100644
--- a/crates/bpe/images/performance-appending.svg
+++ b/crates/bpe/images/performance-appending.svg
@@ -1,52 +1,64 @@
diff --git a/crates/bpe/images/performance-comparison.svg b/crates/bpe/images/performance-comparison.svg
index a6c89f7..aa2ea36 100644
--- a/crates/bpe/images/performance-comparison.svg
+++ b/crates/bpe/images/performance-comparison.svg
@@ -1,58 +1,76 @@
diff --git a/crates/bpe/images/performance-counting.svg b/crates/bpe/images/performance-counting.svg
index 2dff836..20b3e9d 100644
--- a/crates/bpe/images/performance-counting.svg
+++ b/crates/bpe/images/performance-counting.svg
@@ -1,48 +1,64 @@
diff --git a/crates/bpe/images/performance-encoding.svg b/crates/bpe/images/performance-encoding.svg
index a45eec9..46f7d3c 100644
--- a/crates/bpe/images/performance-encoding.svg
+++ b/crates/bpe/images/performance-encoding.svg
@@ -1,76 +1,88 @@
diff --git a/crates/bpe/images/performance-worstcase.svg b/crates/bpe/images/performance-worstcase.svg
index 132b8b3..21b3c61 100644
--- a/crates/bpe/images/performance-worstcase.svg
+++ b/crates/bpe/images/performance-worstcase.svg
@@ -1,91 +1,91 @@
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index 695e549..a90b320 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -553,18 +553,49 @@ impl BytePairEncoding {
}
}
-/// Generate a test string by concatenating random tokens.
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least `min_bytes` long.
#[cfg(feature = "rand")]
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
+ create_test_string_with_predicate(bpe, min_bytes, |_| true)
+}
+
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least `min_bytes` long.
+/// The given predicate enforces additional properties on the generated string. Note that a restrictive predicate
+/// can hurt performance or even cause non-termination!
+#[cfg(feature = "rand")]
+pub fn create_test_string_with_predicate(
+ bpe: &BytePairEncoding,
+ min_bytes: usize,
+ predicate: impl Fn(&str) -> bool,
+) -> String {
use rand::{thread_rng, Rng};
+ // the string we accumulated thus far
let mut result = String::new();
- while result.len() < min_bytes {
- let i = thread_rng().gen_range(0..bpe.num_tokens());
- // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
- // token set. The chance of constructing a valid UTF-8 character across a token boundary
- // by picking random tokens is so small that it is unlikely to happen anyway.
- if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
- result.push_str(token);
+ // the tokens we added so we can backtrack
+ let mut tokens = Vec::new();
+ 'keep: while result.len() < min_bytes {
+ // try a few times to find a suitable token
+ 'next: for _ in 0..8 {
+ // pick a random token and provisionally add it
+ let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
+ // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
+ // token set. The chance of constructing a valid UTF-8 character across a token boundary
+ // by picking random tokens is so small that it is unlikely to happen anyway.
+ if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
+ result.push_str(token);
+ } else {
+ continue 'next;
+ }
+ if predicate(&result) {
+ tokens.push(i);
+ continue 'keep;
+ } else {
+ result.truncate(result.len() - bpe.token_len(i));
+ }
+ }
+ // we didn't find anything after a few tries, backtrack
+ if let Some(i) = tokens.pop() {
+ result.truncate(result.len() - bpe.token_len(i));
}
}
result
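
Usage sketch for the new helper (hedged: `bpe` is assumed to be any `BytePairEncoding` value, and the digit-free predicate is just an example property):

```rust
// Generate at least 1000 bytes of random token text containing no ASCII
// digits. Overly strict predicates can make generation slow or even
// non-terminating, as the doc comment warns.
let s = create_test_string_with_predicate(&bpe, 1_000, |s| {
    !s.chars().any(|c| c.is_ascii_digit())
});
assert!(s.len() >= 1_000);
```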
diff --git a/crates/geo_filters/evaluation/accuracy.rs b/crates/geo_filters/evaluation/accuracy.rs
index 2e2eed0..c3151b7 100644
--- a/crates/geo_filters/evaluation/accuracy.rs
+++ b/crates/geo_filters/evaluation/accuracy.rs
@@ -105,7 +105,8 @@ impl Accuracy {
.config
.iter()
.map(|c| {
- simulation_config_from_str(c).expect(&format!("not a valid configuration: {}", c))
+ simulation_config_from_str(c)
+ .unwrap_or_else(|_| panic!("not a valid configuration: {}", c))
})
.collect_vec();
let set_sizes = if self.set_size.is_empty() {
@@ -118,9 +119,10 @@ impl Accuracy {
let mut output = self.output;
output.set_extension("csv");
- let f = File::create(&output).expect(&format!("cannot create file: {}", output.display()));
+ let f = File::create(&output)
+ .unwrap_or_else(|_| panic!("cannot create file: {}", output.display()));
write_simulation_results(&configs, &set_sizes, results, f)
- .expect(&format!("cannot write file: {}", output.display()));
+ .unwrap_or_else(|_| panic!("cannot write file: {}", output.display()));
println!(" csv file = {}", output.display());
println!();
}
@@ -139,9 +141,9 @@ impl SimulationConfigParser {
Self(Regex::new(re).expect(""), Arc::new(f))
}
- fn parse<'a>(&self, name: &str) -> Option<SimulationConfig> {
+ fn parse(&self, name: &str) -> Option<SimulationConfig> {
self.0
- .captures(&name)
+ .captures(name)
.map(self.1.as_ref())
.map(|p| (name.to_string(), p))
}
@@ -225,7 +227,11 @@ fn simulation_config_from_str(name: &str) -> Result<SimulationConfig> {
fn capture_usizes<const N: usize>(c: &Captures, is: [usize; N]) -> [usize; N] {
let mut values = [0; N];
for i in 0..is.len() {
- values[i] = usize::from_str_radix(c.get(is[i]).expect("capture to exist").as_str(), 10)
+ values[i] = c
+ .get(is[i])
+ .expect("capture to exist")
+ .as_str()
+ .parse::<usize>()
.expect("number string");
}
values
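
For background on the `expect(&format!(...))` to `unwrap_or_else(|_| panic!(...))` changes above: `expect` takes its message by value, so the string is formatted even on the success path, while `unwrap_or_else` defers the formatting to the error case. This is what clippy's `expect_fun_call` lint suggests. A minimal sketch:

```rust
use std::fs::File;
use std::path::Path;

fn open_output(path: &Path) -> File {
    // Eager: allocates the message string even when `create` succeeds:
    //     File::create(path).expect(&format!("cannot create file: {}", path.display()))
    // Lazy: the message is only formatted if `create` fails:
    File::create(path).unwrap_or_else(|_| panic!("cannot create file: {}", path.display()))
}
```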