Skip to content

Commit

Permalink
Make hash factor functions available to users
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Sep 30, 2024
1 parent 8c574d5 commit 646deeb
Showing 1 changed file with 43 additions and 37 deletions.
80 changes: 43 additions & 37 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,32 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
/// when constructing a [`BytePairEncoding`] from those tokens.
#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
pub fn find_hash_factor_from_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
find_hash_factor(|i| bpe._decode_native(&[i]), len)
}

/// Find a suitable hash factor for a set of given tokens that prevents collisions when
/// constructing a [`BytePairEncoding`] from those tokens.
#[cfg(feature = "rand")]
pub fn find_hash_factor(tokens: impl Fn(usize) -> Vec<u8>, len: usize) -> u64 {
use std::collections::HashSet;

use rand::Rng;

let mut rnd = rand::thread_rng();
loop {
let factor: u64 = rnd.gen();
let mut seen = HashSet::with_capacity(len);
if (0..len).all(|i| seen.insert(hash_bytes(&tokens(i), factor))) {
println!("hash factor: {factor}");
return factor;
}
}
}

fn find_token_by_bytes(
all_tokens: &[u8],
token_starts: &[u32],
Expand All @@ -191,8 +217,12 @@ impl BytePairEncoding {
&BPE_O200K
}

/// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
/// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
/// Construct a BytePairEncoding instance from a tiktoken dictionary.
/// A suitable hash factor may be necessary to prevent hash collisions,
/// which can by found using [`crate::data::find_hash_factor_from_tiktoken`].
///
/// The recommended approach is to store the serialized value and reuse that,
/// to prevent repeating the cost of computing the hash factor and encoding.
#[cfg(feature = "tiktoken-rs")]
pub fn from_tiktoken(
tiktoken_bpe: &tiktoken_rs::CoreBPE,
Expand All @@ -205,8 +235,12 @@ impl BytePairEncoding {
)
}

/// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
/// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
/// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
/// A suitable hash factor may be necessary to prevent hash collisions, which can be
/// found using [`crate::data::find_hash_factor`].
///
/// The recommended approach is to store the serialized value and reuse that,
/// to prevent repeating the cost of computing the hash factor and encoding.
pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
let hash_factor = hash_factor
.inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
Expand Down Expand Up @@ -574,53 +608,25 @@ mod tests {

#[cfg(test)]
mod data {
use std::collections::HashSet;
use std::fs::File;
use std::path::PathBuf;

use rand::Rng;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

use super::*;

const BPE_CL100K_LEN: usize = 100256;
const BPE_O200K_LEN: usize = 199998;

/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
/// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for.
/// 2. Update the hash factor in [`hash_bytes`].
/// 3. Run [`update_token_dicts`] tests below to update data files.
/// Note: If you forget this, the next test run will update the files, but
/// all other tests might fail because the data was not up-to-date.
#[test]
#[ignore = "run manually to find a suitable hash factor"]
#[allow(unreachable_code, unused_variables)]
fn find_hash_factor() {
let (bpe, len): (CoreBPE, _) = todo!("replace with BPE instance and token count");
let mut rnd = rand::thread_rng();
loop {
let factor: u64 = rnd.gen();
let mut seen = HashSet::with_capacity(len);
if (0..len).all(|i| seen.insert(hash_bytes(&bpe._decode_native(&[i]), factor))) {
println!("hash factor: {factor}");
return;
}
}
}
use crate::byte_pair_encoding::BytePairEncoding;

#[test]
fn update_token_dicts() {
serialize_tokens(
"cl100k",
&cl100k_base().expect("tiktoken initialization must not fail!"),
BPE_CL100K_LEN,
&tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
100256,
17846336922010275747,
);
serialize_tokens(
"o200k",
&o200k_base().expect("tiktoken initialization must not fail!"),
BPE_O200K_LEN,
&tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
199998,
17846336922010275747,
);
}
Expand Down

0 comments on commit 646deeb

Please sign in to comment.