Generate non-splittable test strings for worstcase benchmark #30

Merged: 4 commits, Oct 22, 2024
5 changes: 2 additions & 3 deletions crates/bpe/README.md
@@ -294,9 +294,8 @@ This suggests that pre-tokenization is not necessary from a performance perspective.

![encoding runtime comparison](./images/performance-comparison.svg)

The graph below shows encoding results for input that is particularly challenging for tiktoken.
The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace.
The performance of tiktoken shows a quadratic growth with the input size.
The graph below shows encoding results when the input cannot be split in pre-tokenization and allows a better comparison of pure BPE performance.
This case is particularly challenging for tiktoken, which shows a quadratic growth with the input size.
The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.

![worst-case encoding runtime comparison](./images/performance-worstcase.svg)
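
For illustration (not part of this diff), a minimal sketch of how such non-splittable input can be produced with the helpers this PR adds. It assumes the benchmark crate's `TOKENIZERS` list and its tokenizer wrapper (`tok.bpe`, `tok.split`, `tok.encode`) exactly as they are used in `performance.rs` below:

```rust
use bpe::byte_pair_encoding::{create_test_string_with_predicate, select_test_string};
use bpe_benchmarks::*;

fn main() {
    for (name, tok, _, _) in TOKENIZERS.iter() {
        // Build ~10 kB of random token text that the pre-tokenizer cannot split:
        // the predicate only accepts text that splits into at most one piece.
        let text = create_test_string_with_predicate(&tok.bpe, 10_000, |text| {
            tok.split(text).nth(1).is_none()
        });
        // Pick a 1 kB window from it, as each benchmark iteration does, and encode it.
        let sample = select_test_string(&text, 1_000);
        let _tokens = tok.encode(sample);
        println!("{name}: generated {} bytes, sampled {} bytes", text.len(), sample.len());
    }
}
```
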
16 changes: 10 additions & 6 deletions crates/bpe/benchmarks/performance.rs
@@ -1,7 +1,9 @@
use std::time::Duration;

use bpe::appendable_encoder::AppendableEncoder;
use bpe::byte_pair_encoding::{create_test_string, select_test_string};
use bpe::byte_pair_encoding::{
create_test_string, create_test_string_with_predicate, select_test_string,
};
use bpe::interval_encoding::IntervalEncoding;
use bpe_benchmarks::*;
use criterion::{
@@ -11,7 +13,7 @@ use rand::{thread_rng, Rng};

fn counting_benchmark(c: &mut Criterion) {
for (name, bpe, _, _) in TOKENIZERS.iter() {
let input = create_test_string(&bpe.bpe, 80000);
let input = create_test_string(&bpe.bpe, 80_000);
let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());

let mut group = c.benchmark_group(format!("counting-{name}"));
@@ -185,19 +187,21 @@ fn comparison_benchmark(c: &mut Criterion) {
}

fn worstcase_comparison_benchmark(c: &mut Criterion) {
for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect();
for (name, tok, tiktoken, huggingface) in TOKENIZERS.iter() {
let text = create_test_string_with_predicate(&tok.bpe, 100000, |text| {
tok.split(text).nth(1).is_none()
Collaborator: This procedure has one downside. You pick ONE pattern with the first couple of bytes. But these different patterns probably have somewhat different performance. In particular, the couple of apostrophe patterns are only up to four characters long :) So, if you are unlucky and pick one of those, then the whole string construction will result in an endless loop...

Contributor (author): I don't think that can happen like that. The predicate is applied to the full text every time it is extended. So if you picked the token 're as the first token, it will be fine until you add anything else, because then it'll split after 're. The backtracking in create_test_string_with_predicate will give up adding anything after 're after a fixed number of tries, remove it, and pick another first token.

});

let mut group = c.benchmark_group(format!("worstcase-{name}"));
for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000] {
group.throughput(criterion::Throughput::Bytes(bytes as u64));
group.bench_with_input(
BenchmarkId::new("backtracking", bytes),
&bytes,
|b, bytes| {
b.iter_batched(
|| select_test_string(&text, *bytes),
|text| bpe.encode(text),
|text| tok.encode(text),
criterion::BatchSize::SmallInput,
)
},
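
To make the review exchange above concrete: the predicate is re-evaluated on the full accumulated text after every candidate token, so (as the author argues) a short apostrophe pattern like 're can only survive as a stand-alone first token and is eventually backtracked away rather than looping forever. A toy, self-contained illustration with a hypothetical splitter (not the crate's real pre-tokenizer):

```rust
// Hypothetical pre-tokenizer: splits right after a leading "'re" (stand-in for the real regex).
fn split(text: &str) -> Vec<&str> {
    match text.strip_prefix("'re") {
        Some(rest) if !rest.is_empty() => vec!["'re", rest],
        _ => vec![text],
    }
}

fn main() {
    // Same shape as the benchmark predicate: accept only text that forms a single piece.
    let is_single_piece = |text: &str| split(text).len() <= 1;

    assert!(is_single_piece("'re")); // fine while "'re" is the only content so far
    assert!(!is_single_piece("'re again")); // any extension creates a second piece, so
                                            // create_test_string_with_predicate rejects it,
                                            // runs out of retries, and backtracks to a new first token
    println!("every extension of \"'re\" is rejected, forcing backtracking");
}
```
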
100 changes: 56 additions & 44 deletions crates/bpe/images/performance-appending.svg
116 changes: 67 additions & 49 deletions crates/bpe/images/performance-comparison.svg
96 changes: 56 additions & 40 deletions crates/bpe/images/performance-counting.svg
140 changes: 76 additions & 64 deletions crates/bpe/images/performance-encoding.svg
164 changes: 82 additions & 82 deletions crates/bpe/images/performance-worstcase.svg
47 changes: 39 additions & 8 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -553,18 +553,49 @@ impl BytePairEncoding {
}
}

/// Generate a test string by concatenating random tokens.
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
#[cfg(feature = "rand")]
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
create_test_string_with_predicate(bpe, min_bytes, |_| true)
}

/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
/// The given predicate enforces other properties on the generated string. Note that this can hurt performance or
/// even cause non-termination!
#[cfg(feature = "rand")]
pub fn create_test_string_with_predicate(
bpe: &BytePairEncoding,
min_bytes: usize,
predicate: impl Fn(&str) -> bool,
) -> String {
use rand::{thread_rng, Rng};
// the string we accumulated thus far
let mut result = String::new();
while result.len() < min_bytes {
let i = thread_rng().gen_range(0..bpe.num_tokens());
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
// token set. The chance of constructing a valid UTF-8 character across a token boundary
// by picking random tokens is so small that it is unlikely to happen anyway.
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
result.push_str(token);
// the tokens we added so we can backtrack
let mut tokens = Vec::new();
'keep: while result.len() < min_bytes {
// try a few times to find a suitable token
'next: for _ in 0..8 {
// pick a random token and provisionally add it
let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
// token set. The chance of constructing a valid UTF-8 character across a token boundary
// by picking random tokens is so small that it is unlikely to happen anyway.
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
result.push_str(token);
} else {
continue 'next;
}
if predicate(&result) {
tokens.push(i);
continue 'keep;
} else {
result.truncate(result.len() - bpe.token_len(i));
}
}
// we didn't find anything after a few tries, backtrack
if let Some(i) = tokens.pop() {
result.truncate(result.len() - bpe.token_len(i));
}
}
result
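
A brief usage sketch for the new helper (illustrative, not from the PR): both functions are gated behind the crate's `rand` feature, and the predicate below is just an example constraint. Per the implementation above, each extension is validated against the full accumulated string, and up to 8 candidate tokens are tried before the helper backtracks by removing the previously accepted token.

```rust
use bpe::byte_pair_encoding::{
    create_test_string, create_test_string_with_predicate, BytePairEncoding,
};

fn demo(bpe: &BytePairEncoding) {
    // Unconstrained random test data, at least 1000 bytes (the previous behavior).
    let any = create_test_string(bpe, 1_000);

    // Constrained test data: only accept extensions that keep the text free of ASCII whitespace.
    // Tokens violating the predicate are undone; after 8 failed candidates the helper removes
    // the last accepted token and retries from there.
    let no_whitespace = create_test_string_with_predicate(bpe, 1_000, |text| {
        !text.contains(|c: char| c.is_ascii_whitespace())
    });

    assert!(any.len() >= 1_000 && no_whitespace.len() >= 1_000);
}
```
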
18 changes: 12 additions & 6 deletions crates/geo_filters/evaluation/accuracy.rs
@@ -105,7 +105,8 @@ impl Accuracy {
.config
.iter()
.map(|c| {
simulation_config_from_str(c).expect(&format!("not a valid configuration: {}", c))
simulation_config_from_str(c)
.unwrap_or_else(|_| panic!("not a valid configuration: {}", c))
})
.collect_vec();
let set_sizes = if self.set_size.is_empty() {
@@ -118,9 +119,10 @@

let mut output = self.output;
output.set_extension("csv");
let f = File::create(&output).expect(&format!("cannot create file: {}", output.display()));
let f = File::create(&output)
.unwrap_or_else(|_| panic!("cannot create file: {}", output.display()));
write_simulation_results(&configs, &set_sizes, results, f)
.expect(&format!("cannot write file: {}", output.display()));
.unwrap_or_else(|_| panic!("cannot write file: {}", output.display()));
println!(" csv file = {}", output.display());
println!();
}
@@ -139,9 +141,9 @@ impl SimulationConfigParser {
Self(Regex::new(re).expect(""), Arc::new(f))
}

fn parse<'a>(&self, name: &str) -> Option<SimulationConfig> {
fn parse(&self, name: &str) -> Option<SimulationConfig> {
self.0
.captures(&name)
.captures(name)
.map(self.1.as_ref())
.map(|p| (name.to_string(), p))
}
@@ -225,7 +227,11 @@ fn simulation_config_from_str(name: &str) -> Result<SimulationConfig, String> {
fn capture_usizes<const N: usize>(c: &Captures, is: [usize; N]) -> [usize; N] {
let mut values = [0; N];
for i in 0..is.len() {
values[i] = usize::from_str_radix(c.get(is[i]).expect("capture to exist").as_str(), 10)
values[i] = c
.get(is[i])
.expect("capture to exist")
.as_str()
.parse::<usize>()
.expect("number string");
}
values