diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml
index 5736ebd..7d3b3aa 100644
--- a/crates/bpe-openai/Cargo.toml
+++ b/crates/bpe-openai/Cargo.toml
@@ -24,7 +24,7 @@ tiktoken-rs = "0.6"
 
 [build-dependencies]
 base64 = "0.22.1"
-bpe = { version = "0.1.0", path = "../bpe" }
+bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken"] }
 flate2 = "1.0"
 rmp-serde = "1"
 serde = { version = "1" }
diff --git a/crates/bpe-openai/build.rs b/crates/bpe-openai/build.rs
index 28b20e7..472e580 100644
--- a/crates/bpe-openai/build.rs
+++ b/crates/bpe-openai/build.rs
@@ -3,35 +3,30 @@ use std::fs::File;
 use std::io::Read;
 use std::path::PathBuf;
 
-use base64::prelude::*;
-use bpe::byte_pair_encoding::BytePairEncoding;
+use bpe::byte_pair_encoding::{read_tiktoken, BytePairEncoding};
 use serde::Serialize;
 
 fn main() {
-    serialize_tokens(
-        "r50k_base",
-        load_tiktoken_gz(include_bytes!("data/r50k_base.tiktoken.gz")),
-        1,
-    );
-    serialize_tokens(
-        "p50k_base",
-        load_tiktoken_gz(include_bytes!("data/p50k_base.tiktoken.gz")),
-        1,
-    );
-    serialize_tokens(
+    serialize_tiktoken_bpe("r50k_base", include_bytes!("data/r50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe("p50k_base", include_bytes!("data/p50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe(
         "cl100k_base",
-        load_tiktoken_gz(include_bytes!("data/cl100k_base.tiktoken.gz")),
+        include_bytes!("data/cl100k_base.tiktoken.gz"),
         17846336922010275747,
     );
-    serialize_tokens(
+    serialize_tiktoken_bpe(
         "o200k_base",
-        load_tiktoken_gz(include_bytes!("data/o200k_base.tiktoken.gz")),
+        include_bytes!("data/o200k_base.tiktoken.gz"),
         17846336922010275747,
     );
     println!("cargo::rerun-if-changed=build.rs");
 }
 
-fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
+fn serialize_tiktoken_bpe(name: &str, data: &[u8], hash_factor: u64) {
+    let mut dec = flate2::read::GzDecoder::new(data);
+    let mut tiktoken = String::new();
+    dec.read_to_string(&mut tiktoken).expect("can decode data");
+    let tokens = read_tiktoken(&tiktoken).expect("can read data");
     let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set during build"));
     path.push(format!("bpe_{name}.dict"));
     let file = File::create(path).expect("can create output file");
@@ -40,19 +35,3 @@ fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
     bpe.serialize(&mut serializer)
         .expect("serialization succeeds");
 }
-
-fn load_tiktoken_gz(data: &[u8]) -> Vec<Vec<u8>> {
-    let mut dec = flate2::read::GzDecoder::new(data);
-    let mut tiktoken = String::new();
-    dec.read_to_string(&mut tiktoken).expect("can decode data");
-    let tokens: Vec<_> = tiktoken
-        .lines()
-        .filter(|line| !line.is_empty())
-        .map(|line| {
-            BASE64_STANDARD
-                .decode(line.split_whitespace().next().expect("token field on line"))
-                .expect("base64 token field")
-        })
-        .collect();
-    tokens
-}
diff --git a/crates/bpe/Cargo.toml b/crates/bpe/Cargo.toml
index a315292..a0a68f8 100644
--- a/crates/bpe/Cargo.toml
+++ b/crates/bpe/Cargo.toml
@@ -14,9 +14,11 @@ bench = false
 
 [features]
 rand = ["dep:rand"]
+tiktoken = ["dep:base64"]
 
 [dependencies]
 aneubeck-daachorse = "1.1.1"
+base64 = { version = "0.22", optional = true }
 fnv = "1.0"
 itertools = "0.12"
 rand = { version = "0.8", optional = true }
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index 23a6adc..9efbb0e 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -155,6 +155,13 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
     ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }
 
+/// Find a suitable hash factor for the given tiktoken data that prevents collisions when
+/// constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(all(feature = "rand", feature = "tiktoken"))]
+pub fn find_hash_factor_for_tiktoken(data: &str) -> Result<u64, base64::DecodeError> {
+    Ok(find_hash_factor_for_dictionary(read_tiktoken(data)?))
+}
+
 /// Find a suitable hash factor for a set of given tokens that prevents collisions when
 /// constructing a [`BytePairEncoding`] from those tokens.
 #[cfg(feature = "rand")]
@@ -193,7 +200,38 @@ fn find_token_by_bytes(
     }
 }
 
+/// Read the tokens from a tiktoken data file, which contains base64 encoded tokens at
+/// the start of each line, in descending frequency order.
+#[cfg(feature = "tiktoken")]
+pub fn read_tiktoken(data: &str) -> Result<Vec<Vec<u8>>, base64::DecodeError> {
+    use base64::prelude::*;
+    data.lines()
+        .filter(|line| !line.is_empty())
+        .map(|line| {
+            let encoded_token = line
+                .split_whitespace()
+                .next()
+                .expect("non-empty line has first field");
+            BASE64_STANDARD.decode(encoded_token)
+        })
+        .try_collect()
+}
+
 impl BytePairEncoding {
+    /// Construct a BytePairEncoding instance from a tiktoken data file.
+    /// A suitable hash factor may be necessary to prevent hash collisions, which can be
+    /// found using [`find_hash_factor_for_tiktoken`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
+    #[cfg(feature = "tiktoken")]
+    pub fn from_tiktoken(
+        data: &str,
+        hash_factor: Option<u64>,
+    ) -> Result<Self, base64::DecodeError> {
+        Ok(Self::from_dictionary(read_tiktoken(data)?, hash_factor))
+    }
+
     /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
     /// A suitable hash factor may be necessary to prevent hash collisions, which can be
     /// found using [`find_hash_factor_for_dictionary`].