Skip to content

Commit

Permalink
Merge branch 'move-equivalence-tests' into unsplittable-test-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Oct 21, 2024
2 parents 5b7d913 + b42989e commit 0cb520e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 99 deletions.
2 changes: 1 addition & 1 deletion crates/bpe-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ regex-automata = "0.4"
rmp-serde = "1"

[dev-dependencies]
tiktoken-rs = "0.6"
bpe = { version = "0.1.0", path = "../bpe", features = ["rand"] }
tiktoken-rs = "0.6"

[build-dependencies]
base64 = "0.22.1"
Expand Down
70 changes: 29 additions & 41 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,17 +553,6 @@ impl BytePairEncoding {
}
}

#[cfg(feature = "rand")]
fn is_char_boundary(b: u8) -> bool {
    // UTF-8 classifies bytes by their leading bits:
    //   0xxxxxxx (0x00..0x80): single-byte scalar value -> starts a character
    //   10xxxxxx (0x80..0xC0): continuation byte        -> never starts one
    //   11xxxxxx (0xC0..)    : leading multi-byte byte  -> starts a character
    // A byte is therefore a character boundary exactly when it falls outside
    // the continuation range [0x80, 0xC0).
    b < 0x80 || b >= 0xC0
}

/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least `min_bytes` long.
#[cfg(feature = "rand")]
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
Expand All @@ -580,49 +569,36 @@ pub fn create_test_string_with_predicate(
predicate: impl Fn(&str) -> bool,
) -> String {
use rand::{thread_rng, Rng};
// the bytes we accumulated thus far
let mut bytes = Vec::new();
// the string we accumulated thus far
let mut result = String::new();
// the tokens we added so we can backtrack
let mut tokens = Vec::new();
// the number of valid UTF-8 bytes
let mut valid_bytes = 0;
'keep: while valid_bytes < min_bytes {
'keep: while result.len() < min_bytes {
// try a few times to find a suitable token
for _ in 0..8 {
'next: for _ in 0..8 {
// pick a random token and provisionally add it
let i = thread_rng().gen_range(0..bpe.num_tokens());
bytes.extend(bpe.token_bytes(i as u32));
// test if the additional bytes are valid utf-8
// the last character is not included, because it may be incomplete
let last = bytes
.iter()
.rev()
.find_position(|b| is_char_boundary(**b))
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
assert!(last >= valid_bytes);
if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok()
&& predicate(std::str::from_utf8(&bytes[0..last]).expect("should be valid"))
{
let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
// token set. The chance of constructing a valid UTF-8 character across a token boundary
// by picking random tokens is so small that it is unlikely to happen anyway.
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
result.push_str(token);
} else {
continue 'next;
}
if predicate(&result) {
tokens.push(i);
valid_bytes = last;
continue 'keep;
} else {
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
result.truncate(result.len() - bpe.token_len(i));
}
}
// we didn't find anything after a few tries, backtrack
if let Some(i) = tokens.pop() {
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
valid_bytes = bytes
.iter()
.rev()
.find_position(|b| is_char_boundary(**b))
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
result.truncate(result.len() - bpe.token_len(i));
}
}
// truncate to the known valid bytes
bytes.truncate(valid_bytes);
String::from_utf8(bytes).expect("should be valid here")
result
}

#[cfg(feature = "rand")]
Expand All @@ -638,3 +614,15 @@ pub fn select_test_string(text: &str, min_bytes: usize) -> &str {
}
&text[start..end]
}

/// Build a byte sequence of at least `min_bytes` bytes by concatenating the
/// byte representations of uniformly random tokens from `bpe`.
///
/// NOTE(review): since tokens are appended as raw bytes, the result is not
/// necessarily valid UTF-8.
#[cfg(feature = "rand")]
pub fn create_test_bytes(bpe: &BytePairEncoding, min_bytes: usize) -> Vec<u8> {
    use rand::{thread_rng, Rng};
    // Reserve the minimum up front; the loop overshoots by at most one token.
    let mut bytes = Vec::with_capacity(min_bytes);
    while bytes.len() < min_bytes {
        let token = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
        bytes.extend(bpe.token_bytes(token));
    }
    bytes
}
83 changes: 26 additions & 57 deletions crates/bpe/tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
#[cfg(test)]
mod tests {
use std::time::Instant;

use itertools::Itertools;
use rand::{thread_rng, Rng};
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
use tiktoken_rs::cl100k_base_singleton;

use bpe::appendable_encoder::AppendableEncoder;
use bpe::byte_pair_encoding::{create_test_string, BytePairEncoding};
use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use bpe::interval_encoding::IntervalEncoding;
use bpe::prependable_encoder::PrependableEncoder;
use bpe_openai::{cl100k_base, o200k_base};
use bpe_openai::cl100k_base;

/// This test produces the output for the encoding example in the README.
#[test]
Expand Down Expand Up @@ -72,93 +70,64 @@ mod tests {
fn test_appendable_encoder() {
let bpe = &cl100k_base().bpe;
let mut enc = AppendableEncoder::new(bpe);
let input_string = create_test_string(bpe, 100);
for (i, b) in input_string.as_bytes().iter().enumerate() {
let input = create_test_bytes(bpe, 100);
for (i, b) in input.iter().enumerate() {
enc.push(*b);
assert_eq!(
enc.token_count(),
bpe.count(&input_string.as_bytes()[0..i + 1])
);
assert_eq!(enc.token_count(), bpe.count(&input[0..i + 1]));
}
}

#[test]
fn test_correctness_cl100k() {
fn test_correctness() {
// This is quite a challenging test case...
let test_string = std::str::from_utf8(&[
let input = std::str::from_utf8(&[
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
102, 102, 101, 110, 100,
])
.unwrap();
let time = Instant::now();
let bpe = &cl100k_base().bpe;
println!("{:?}", time.elapsed());
let encoded1 = cl100k_base_singleton()
.lock()
.encode_ordinary(test_string)
.into_iter()
.collect_vec();
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
assert_eq!(encoded1, encoded2);
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
assert_eq!(encoded1, encoded3);
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
assert_eq!(encoded1, encoded4);
}

#[test]
fn test_correctness_o200k() {
// This is quite a challenging test case...
let test_string = std::str::from_utf8(&[
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
102, 102, 101, 110, 100,
])
.unwrap();
let time = Instant::now();
let bpe = &o200k_base().bpe;
println!("{:?}", time.elapsed());
let encoded1 = o200k_base_singleton()
.lock()
.encode_ordinary(test_string)
.encode_ordinary(input)
.into_iter()
.collect_vec();
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
let encoded2 = bpe.encode_via_backtracking(input.as_bytes());
assert_eq!(encoded1, encoded2);
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
let encoded3 = bpe.encode_via_table(input.as_bytes());
assert_eq!(encoded1, encoded3);
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
let encoded4 = bpe.encode_via_bitfield(input.as_bytes());
assert_eq!(encoded1, encoded4);
}

#[test]
fn test_bpe_equivalence() {
let bpe = &cl100k_base().bpe;
for bytes in [10, 1000, 10000] {
for _ in 0..5 {
let test_input = create_test_string(bpe, bytes);
let encoded1 = bpe.encode_via_backtracking(test_input.as_bytes());
let encoded2 = bpe.encode_via_bitfield(test_input.as_bytes());
for _ in 0..8 {
let input = create_test_bytes(bpe, bytes);
let encoded1 = bpe.encode_via_backtracking(&input);
let encoded2 = bpe.encode_via_bitfield(&input);
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
let encoded3 = bpe.encode_via_table(&input);
assert_eq!(encoded1, encoded3, "{} {}", encoded1.len(), encoded3.len());
}
}
}

#[test]
fn test_interval_count() {
let bpe = &cl100k_base().bpe;
let text = create_test_string(bpe, 10000);
let intervals = IntervalEncoding::new(bpe, text.as_bytes());
let input = create_test_bytes(bpe, 10000);
let intervals = IntervalEncoding::new(bpe, &input);
for _ in 0..1000 {
let start = thread_rng().gen_range(0..text.len());
let end = thread_rng().gen_range(0..text.len());
let start = thread_rng().gen_range(0..input.len());
let end = thread_rng().gen_range(0..input.len());
let range = start.min(end)..start.max(end);
assert_eq!(
intervals.count(range.clone()),
bpe.encode_via_backtracking(&text.as_bytes()[range]).len()
bpe.encode_via_backtracking(&input[range]).len()
);
}
}
Expand All @@ -167,10 +136,10 @@ mod tests {
fn test_prependable_encoder() {
let bpe = &cl100k_base().bpe;
let mut enc = PrependableEncoder::new(bpe);
let input_string = create_test_string(bpe, 100);
for (i, b) in input_string.as_bytes().iter().enumerate().rev() {
let input = create_test_bytes(bpe, 100);
for (i, b) in input.iter().enumerate().rev() {
enc.push(*b);
assert_eq!(enc.token_count(), bpe.count(&input_string.as_bytes()[i..]));
assert_eq!(enc.token_count(), bpe.count(&input[i..]));
}
}
}

0 comments on commit 0cb520e

Please sign in to comment.