Move tiktoken data reading to bpe
hendrikvanantwerpen committed Oct 14, 2024
1 parent 3132551 commit 348abf0
Showing 4 changed files with 53 additions and 34 deletions.
2 changes: 1 addition & 1 deletion crates/bpe-openai/Cargo.toml
@@ -24,7 +24,7 @@ tiktoken-rs = "0.6"
 
 [build-dependencies]
 base64 = "0.22.1"
-bpe = { version = "0.1.0", path = "../bpe" }
+bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken"] }
 flate2 = "1.0"
 rmp-serde = "1"
 serde = { version = "1" }
45 changes: 12 additions & 33 deletions crates/bpe-openai/build.rs
@@ -3,35 +3,30 @@ use std::fs::File;
 use std::io::Read;
 use std::path::PathBuf;
 
-use base64::prelude::*;
-use bpe::byte_pair_encoding::BytePairEncoding;
+use bpe::byte_pair_encoding::{read_tiktoken, BytePairEncoding};
 use serde::Serialize;
 
 fn main() {
-    serialize_tokens(
-        "r50k_base",
-        load_tiktoken_gz(include_bytes!("data/r50k_base.tiktoken.gz")),
-        1,
-    );
-    serialize_tokens(
-        "p50k_base",
-        load_tiktoken_gz(include_bytes!("data/p50k_base.tiktoken.gz")),
-        1,
-    );
-    serialize_tokens(
+    serialize_tiktoken_bpe("r50k_base", include_bytes!("data/r50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe("p50k_base", include_bytes!("data/p50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe(
         "cl100k_base",
-        load_tiktoken_gz(include_bytes!("data/cl100k_base.tiktoken.gz")),
+        include_bytes!("data/cl100k_base.tiktoken.gz"),
         17846336922010275747,
     );
-    serialize_tokens(
+    serialize_tiktoken_bpe(
         "o200k_base",
-        load_tiktoken_gz(include_bytes!("data/o200k_base.tiktoken.gz")),
+        include_bytes!("data/o200k_base.tiktoken.gz"),
         17846336922010275747,
     );
     println!("cargo::rerun-if-changed=build.rs");
 }
 
-fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
+fn serialize_tiktoken_bpe(name: &str, data: &[u8], hash_factor: u64) {
+    let mut dec = flate2::read::GzDecoder::new(data);
+    let mut tiktoken = String::new();
+    dec.read_to_string(&mut tiktoken).expect("can decode data");
+    let tokens = read_tiktoken(&tiktoken).expect("can read data");
     let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set during build"));
     path.push(format!("bpe_{name}.dict"));
     let file = File::create(path).expect("can create output file");
@@ -40,19 +35,3 @@ fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
     bpe.serialize(&mut serializer)
         .expect("serialization succeeds");
 }
-
-fn load_tiktoken_gz(data: &[u8]) -> Vec<Vec<u8>> {
-    let mut dec = flate2::read::GzDecoder::new(data);
-    let mut tiktoken = String::new();
-    dec.read_to_string(&mut tiktoken).expect("can decode data");
-    let tokens: Vec<_> = tiktoken
-        .lines()
-        .filter(|line| !line.is_empty())
-        .map(|line| {
-            BASE64_STANDARD
-                .decode(line.split_whitespace().next().expect("token field on line"))
-                .expect("base64 token field")
-        })
-        .collect();
-    tokens
-}
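
The build script now only decompresses the gzipped tiktoken data and delegates parsing to the new read_tiktoken helper in the bpe crate, serializing the resulting BytePairEncoding to $OUT_DIR/bpe_{name}.dict via rmp-serde. The runtime side is not part of this diff; below is a minimal sketch of how such an artifact could be embedded and deserialized, assuming BytePairEncoding also implements Deserialize (the static name and OnceLock wrapper are illustrative, not part of this commit):

use std::sync::OnceLock;

use bpe::byte_pair_encoding::BytePairEncoding;

// Illustrative only: embed the dictionary emitted by build.rs and decode it once.
static BPE_R50K_BASE: OnceLock<BytePairEncoding> = OnceLock::new();

fn r50k_base() -> &'static BytePairEncoding {
    BPE_R50K_BASE.get_or_init(|| {
        let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
        rmp_serde::from_slice(bytes).expect("dictionary deserializes")
    })
}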
2 changes: 2 additions & 0 deletions crates/bpe/Cargo.toml
@@ -14,9 +14,11 @@ bench = false
 
 [features]
 rand = ["dep:rand"]
+tiktoken = ["dep:base64"]
 
 [dependencies]
 aneubeck-daachorse = "1.1.1"
+base64 = { version = "0.22", optional = true }
 fnv = "1.0"
 itertools = "0.12"
 rand = { version = "0.8", optional = true }
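
Note that find_hash_factor_for_tiktoken below is gated on both features, via #[cfg(all(feature = "rand", feature = "tiktoken"))], so a downstream crate that wants to compute hash factors from tiktoken data would enable both. A sketch of such a hypothetical consumer manifest:

[dependencies]
# "tiktoken" enables read_tiktoken/from_tiktoken; "rand" additionally
# enables the hash-factor search.
bpe = { version = "0.1.0", features = ["rand", "tiktoken"] }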
38 changes: 38 additions & 0 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -155,6 +155,13 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
     ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }
 
+/// Find a suitable hash factor for the given tiktoken data that prevents collisions when
+/// constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(all(feature = "rand", feature = "tiktoken"))]
+pub fn find_hash_factor_for_tiktoken(data: &str) -> Result<u64, base64::DecodeError> {
+    Ok(find_hash_factor_for_dictionary(read_tiktoken(data)?))
+}
+
 /// Find a suitable hash factor for a set of given tokens that prevents collisions when
 /// constructing a [`BytePairEncoding`] from those tokens.
 #[cfg(feature = "rand")]
@@ -193,7 +200,38 @@ fn find_token_by_bytes(
     }
 }
 
+/// Read the tokens from a tiktoken data file, which contains base64 encoded tokens at
+/// the start of each line, in descending frequency order.
+#[cfg(feature = "tiktoken")]
+pub fn read_tiktoken(data: &str) -> Result<Vec<Vec<u8>>, base64::DecodeError> {
+    use base64::prelude::*;
+    data.lines()
+        .filter(|line| !line.is_empty())
+        .map(|line| {
+            let encoded_token = line
+                .split_whitespace()
+                .next()
+                .expect("non-empty line has first field");
+            BASE64_STANDARD.decode(encoded_token)
+        })
+        .try_collect()
+}
+
 impl BytePairEncoding {
+    /// Construct a BytePairEncoding instance from a tiktoken data file.
+    /// A suitable hash factor may be necessary to prevent hash collisions, which can be
+    /// found using [`find_hash_factor_for_tiktoken`].
+    ///
+    /// The recommended approach is to store the serialized value and reuse that,
+    /// to prevent repeating the cost of computing the hash factor and encoding.
+    #[cfg(feature = "tiktoken")]
+    pub fn from_tiktoken(
+        data: &str,
+        hash_factor: Option<u64>,
+    ) -> Result<Self, base64::DecodeError> {
+        Ok(Self::from_dictionary(read_tiktoken(data)?, hash_factor))
+    }
+
     /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
     /// A suitable hash factor may be necessary to prevent hash collisions, which can be
     /// found using [`find_hash_factor_for_dictionary`].
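
Taken together, these additions move tiktoken ingestion out of the build script and into the bpe crate's public API. A minimal usage sketch, assuming both the "rand" and "tiktoken" features are enabled and using an illustrative file path:

use bpe::byte_pair_encoding::{find_hash_factor_for_tiktoken, BytePairEncoding};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One base64-encoded token per line, in descending frequency order.
    let data = std::fs::read_to_string("my_tokens.tiktoken")?;
    // Search once for a collision-free hash factor; cache it alongside the data.
    let hash_factor = find_hash_factor_for_tiktoken(&data)?;
    let _bpe = BytePairEncoding::from_tiktoken(&data, Some(hash_factor))?;
    // As the doc comment recommends, serialize the result (as build.rs does)
    // and reuse the stored form instead of repeating this work on every run.
    Ok(())
}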
