Merge pull request #16 from github/add-o200k-tokens
Add support for o200k tokenization
hendrikvanantwerpen authored Sep 25, 2024
2 parents 2112a0e + 2a4deef commit 3a14632
Showing 7 changed files with 166 additions and 86 deletions.
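
As a quick orientation (not part of the diff): the commit exposes the new o200k dictionary next to the existing cl100k one. The sketch below exercises it with the same calls the benchmarks in this commit use; the import path, the main wrapper, and the printed types are assumptions based on the crate layout, not taken from the diff.

// Hedged sketch: using the new o200k dictionary via the functions exercised in the benchmarks.
use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    let bpe = BytePairEncoding::o200k();
    let input = "Hello, o200k!".as_bytes();
    // Count the number of tokens without materializing them (see the counting benchmark).
    let count = bpe.count(input);
    // Encode into token ids with the backtracking encoder (see the encoding benchmark).
    let tokens = bpe.encode_via_backtracking(input);
    println!("{count} tokens: {tokens:?}");
}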
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,4 +2,4 @@ Cargo.lock
/target/
/crates/*/target/
/crates/*/Cargo.lock
.vscode/
.vscode/
161 changes: 88 additions & 73 deletions crates/bpe/benches/counting.rs
@@ -6,85 +6,100 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rand::{thread_rng, Rng};

fn counting_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 20000);
for (name, bpe) in [
("cl100k", BytePairEncoding::cl100k()),
("o200k", BytePairEncoding::o200k()),
] {
let text = create_test_bytes(&bpe, 20000);
let fast = IntervalEncoding::new(&bpe, &text);

let fast = IntervalEncoding::new(&bpe, &text);

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}

fn encoding_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let tiktoken = tiktoken_rs::cl100k_base().unwrap();
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();
for (name, bpe, tiktoken) in [
(
"cl100k",
BytePairEncoding::cl100k(),
tiktoken_rs::cl100k_base().unwrap(),
),
(
"o200k",
BytePairEncoding::o200k(),
tiktoken_rs::o200k_base().unwrap(),
),
] {
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) {
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes])
{
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}

84 changes: 74 additions & 10 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -19,6 +19,11 @@ static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
rmp_serde::from_slice(bytes).expect("")
});

static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
let bytes = include_bytes!("data/bpe_o200k.dict");
rmp_serde::from_slice(bytes).expect("")
});

/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
@@ -153,11 +158,15 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
}

fn hash_bytes(bytes: &[u8]) -> u32 {
hash_bytes_with_factor(bytes, 17846336922010275747)
}

fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
let mut hasher = FnvHasher::default();
bytes.hash(&mut hasher);
// Note: we save 1/3 of the space for the hashmap by using only the most significant bits of the hash.
// To make them unique for the given tokens, we unfortunately have to add another multiplication.
((hasher.finish().wrapping_mul(37493864257)) >> 32) as u32
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

fn find_token_by_bytes(
@@ -180,6 +189,10 @@ impl BytePairEncoding {
&BPE_CL100K
}

pub fn o200k() -> &'static Self {
&BPE_O200K
}

/// Construct a BytePairEncoding instance from a tiktoken dictionary.
pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
@@ -492,13 +505,11 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {

#[cfg(test)]
mod tests {
use std::fs::File;
use std::path::PathBuf;

use std::time::Instant;

use itertools::Itertools;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
use tiktoken_rs::cl100k_base_singleton;

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

@@ -541,19 +552,72 @@ mod tests {
}
}
}
}

// TODO: Move the generation of the dictionary into some build procedure?
#[cfg(test)]
mod data {
use std::collections::HashSet;
use std::fs::File;
use std::path::PathBuf;

use rand::Rng;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, o200k_base};

use super::*;

const BPE_CL100K_LEN: usize = 100256;
const BPE_O200K_LEN: usize = 199998;

/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
/// 1. Ensure all supported tokenizers are in the list.
/// 2. Update the hash factor in [`hash_bytes`].
/// 3. Run [`update_token_dicts`] tests below to update data files.
#[test]
fn test_serialize() {
#[ignore = "run manually to find a suitable hash factor"]
fn find_hash_factor() {
let bpes: &mut [(CoreBPE, usize)] = &mut [
(cl100k_base().unwrap(), BPE_CL100K_LEN),
(o200k_base().unwrap(), BPE_O200K_LEN),
];
let mut rnd = rand::thread_rng();
loop {
let factor: u64 = rnd.gen();
if bpes.iter().all(|(bpe, len)| {
let mut seen = HashSet::with_capacity(*len);
(0..*len)
.all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
}) {
println!("hash factor: {factor}");
return;
}
}
}

#[test]
fn update_token_dicts() {
serialize_tokens(
&cl100k_base().expect("tiktoken initialization must not fail!"),
BPE_CL100K_LEN,
"cl100k",
);
serialize_tokens(
&o200k_base().expect("tiktoken initialization must not fail!"),
BPE_O200K_LEN,
"o200k",
);
}

#[track_caller]
fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
let path = PathBuf::from(file!());
let dir = path.parent().unwrap();
let data_file = dir.join("data/bpe_cl100k.dict");
let data_file = dir.join(format!("data/bpe_{name}.dict"));
let current_dir = std::env::current_dir().unwrap();
let abs_path = current_dir.parent().unwrap().parent().unwrap();
let file = File::create(abs_path.join(data_file)).unwrap();
let mut serializer = rmp_serde::Serializer::new(file);
let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
BytePairEncoding::from_tiktoken(dict, num_tokens)
.serialize(&mut serializer)
.unwrap();
}
Binary file modified crates/bpe/src/data/bpe_cl100k.dict
Binary file not shown.
Binary file added crates/bpe/src/data/bpe_o200k.dict
Binary file not shown.
3 changes: 1 addition & 2 deletions crates/geo_filters/src/config/bitchunks.rs
@@ -217,9 +217,8 @@ pub(crate) fn count_ones_from_bitchunks<T: IsBucketType>(
let mut total = take_ref(&mut ones, max_msb_len - 1).count();
let smallest_msb = ones
.next()
.map(|bucket| {
.inspect(|_| {
total += 1;
bucket
})
.unwrap_or_default();

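The change above replaces map with inspect, which runs a side effect on each yielded value without having to return it from the closure. A minimal standalone sketch of that idiom (not part of the diff; the values are made up):

fn main() {
    let mut total = 0;
    let smallest = [5u32, 2, 9]
        .into_iter()
        // Count how many values flow past; nothing is mapped or consumed here.
        .inspect(|_| total += 1)
        .min()
        .unwrap_or_default();
    assert_eq!((smallest, total), (2, 3));
}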
2 changes: 2 additions & 0 deletions criterion.toml
@@ -0,0 +1,2 @@
# save report in this directory, even if a custom target directory is set
criterion_home = "./target/criterion"
