diff --git a/crates/bpe/Cargo.toml b/crates/bpe/Cargo.toml
index 1d169d8..4050236 100644
--- a/crates/bpe/Cargo.toml
+++ b/crates/bpe/Cargo.toml
@@ -7,20 +7,24 @@ edition = "2021"
 crate-type = ["lib", "staticlib"]
 bench = false
 
+[[bench]]
+name = "counting"
+path = "benches/counting.rs"
+harness = false
+
+[features]
+rand = ["dep:rand"]
+tiktoken-rs = ["dep:tiktoken-rs"]
+
 [dependencies]
 aneubeck-daachorse = "1.1.1"
 fnv = "1.0"
 itertools = "0.12"
-once_cell = "1"
-rand = "0.8"
+rand = { version = "0.8", optional = true }
 rmp-serde = "1"
 serde = { version = "1", features = ["derive"] }
-tiktoken-rs = "0.5"
+tiktoken-rs = { version = "0.5", optional = true }
 
 [dev-dependencies]
+bpe = { path = ".", features = ["rand", "tiktoken-rs"] }
 criterion = "0.5"
-
-[[bench]]
-name = "counting"
-path = "benches/counting.rs"
-harness = false
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index 053bc0c..5e172d3 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -2,24 +2,23 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
+use std::sync::LazyLock;
 
 use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
 use itertools::Itertools;
-use once_cell::sync::Lazy;
 use serde::de::Visitor;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use tiktoken_rs::CoreBPE;
 
 use crate::backtrack_encoder::BacktrackEncoder;
 use crate::bitfield::BitField;
 
-static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
+static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
     let bytes = include_bytes!("data/bpe_cl100k.dict");
     rmp_serde::from_slice(bytes).expect("")
 });
 
-static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
+static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
     let bytes = include_bytes!("data/bpe_o200k.dict");
     rmp_serde::from_slice(bytes).expect("")
 });
@@ -194,7 +193,8 @@ impl BytePairEncoding {
     }
 
     /// Construct a BytePairEncoding instance from a tiktoken dictionary.
-    pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
+    #[cfg(feature = "tiktoken-rs")]
+    pub fn from_tiktoken(tiktoken_bpe: &tiktoken_rs::CoreBPE, num_tokens: usize) -> Self {
         Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
     }
 
@@ -492,6 +492,7 @@ impl BytePairEncoding {
     }
 }
 
+#[cfg(feature = "rand")]
 pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
     use rand::{thread_rng, Rng};
     let mut text = vec![];
@@ -576,7 +577,7 @@ mod data {
     #[test]
     #[ignore = "run manually to find a suitable hash factor"]
     fn find_hash_factor() {
-        let bpes: &mut [(CoreBPE, usize)] = &mut [
+        let bpes = &mut [
             (cl100k_base().unwrap(), BPE_CL100K_LEN),
             (o200k_base().unwrap(), BPE_O200K_LEN),
         ];
@@ -609,7 +610,7 @@
     }
 
     #[track_caller]
-    fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
+    fn serialize_tokens(dict: &tiktoken_rs::CoreBPE, num_tokens: usize, name: &str) {
         let path = PathBuf::from(file!());
         let dir = path.parent().unwrap();
         let data_file = dir.join(format!("data/bpe_{name}.dict"));
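
For reference, here is a minimal sketch of how a downstream crate might exercise the new feature gates. The feature names, `BytePairEncoding::from_tiktoken`, and `create_test_bytes` come from the diff above; the `bpe` dependency spec, the `bpe::byte_pair_encoding` module path, and the `CL100K_TOKENS` constant are assumptions for illustration, and the downstream crate is assumed to depend on `tiktoken-rs` directly as well.

```rust
// Sketch of downstream usage with both optional features enabled.
// Assumed Cargo.toml entries (not part of the diff):
//   bpe = { path = "crates/bpe", features = ["rand", "tiktoken-rs"] }
//   tiktoken-rs = "0.5"

// Module path assumed from src/byte_pair_encoding.rs.
use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

// Placeholder for the real cl100k vocabulary size (the crate's own tests use
// their BPE_CL100K_LEN constant for this).
const CL100K_TOKENS: usize = 100_256;

fn main() {
    // cl100k_base() is the same tiktoken-rs helper the crate's tests call.
    let tiktoken = tiktoken_rs::cl100k_base().expect("cl100k dictionary should load");

    // Only compiled into `bpe` when the "tiktoken-rs" feature is enabled.
    let bpe = BytePairEncoding::from_tiktoken(&tiktoken, CL100K_TOKENS);

    // Only compiled into `bpe` when the "rand" feature is enabled.
    let bytes = create_test_bytes(&bpe, 1_000);
    println!("generated {} random test bytes", bytes.len());
}
```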