Skip to content

Commit b42989e

Browse files
Simplify test string generation and reinstate random byte inputs where possible
1 parent c75707f commit b42989e

File tree

3 files changed

+49
-108
lines changed

3 files changed

+49
-108
lines changed

crates/bpe-openai/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@ regex-automata = "0.4"
1919
rmp-serde = "1"
2020

2121
[dev-dependencies]
22-
tiktoken-rs = "0.6"
2322
bpe = { version = "0.1.0", path = "../bpe", features = ["rand"] }
23+
tiktoken-rs = "0.6"
2424

2525
[build-dependencies]
2626
base64 = "0.22.1"

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 22 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -553,61 +553,21 @@ impl BytePairEncoding {
553553
}
554554
}
555555

556-
#[cfg(feature = "rand")]
557-
fn is_char_boundary(b: u8) -> bool {
558-
// Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128
559-
// Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. b < 192
560-
// The rest are bytes belonging to the first byte of multi byte encodings (11xxxxxx): b >= 192
561-
// When interpreting the byte representation as signed integers, then numbers in the range 128..192
562-
// correspond to the smallest representable numbers. I.e. the two ranges [0, 128) and [192, 256) can
563-
// be tested with a single signed comparison.
564-
b as i8 >= -0x40 // NB: b < 128 || b >= 192
565-
}
566-
556+
/// Generate a test string by concatenating random tokens.
567557
#[cfg(feature = "rand")]
568558
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
569559
use rand::{thread_rng, Rng};
570-
// the bytes we accumulated thus far
571-
let mut bytes = Vec::new();
572-
// the tokens we added so we can backtrack
573-
let mut tokens = Vec::new();
574-
// the number of valid UTF-8 bytes
575-
let mut valid_bytes = 0;
576-
'keep: while valid_bytes < min_bytes {
577-
// try a few times to find a suitable token
578-
for _ in 0..8 {
579-
// pick a random token and provisionally add it
580-
let i = thread_rng().gen_range(0..bpe.num_tokens());
581-
bytes.extend(bpe.token_bytes(i as u32));
582-
// test if the additional bytes are valid utf-8
583-
// the last character is not included, because it may be incomplete
584-
let last = bytes
585-
.iter()
586-
.rev()
587-
.find_position(|b| is_char_boundary(**b))
588-
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
589-
assert!(last >= valid_bytes);
590-
if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok() {
591-
tokens.push(i);
592-
valid_bytes = last;
593-
continue 'keep;
594-
} else {
595-
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
596-
}
597-
}
598-
// we didn't find anything after a few tries, backtrack
599-
if let Some(i) = tokens.pop() {
600-
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
601-
valid_bytes = bytes
602-
.iter()
603-
.rev()
604-
.find_position(|b| is_char_boundary(**b))
605-
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
560+
let mut result = String::new();
561+
while result.len() < min_bytes {
562+
let i = thread_rng().gen_range(0..bpe.num_tokens());
563+
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
564+
// token set. The chance of constructing a valid UTF-8 character across a token boundary
565+
// by picking random tokens is so small that it is unlikely to happen anyway.
566+
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
567+
result.push_str(token);
606568
}
607569
}
608-
// truncate to the known valid bytes
609-
bytes.truncate(valid_bytes);
610-
String::from_utf8(bytes).expect("should be valid here")
570+
result
611571
}
612572

613573
#[cfg(feature = "rand")]
@@ -623,3 +583,15 @@ pub fn select_test_string(text: &str, min_bytes: usize) -> &str {
623583
}
624584
&text[start..end]
625585
}
586+
587+
/// Generate test bytes by concatenating random tokens.
588+
#[cfg(feature = "rand")]
589+
pub fn create_test_bytes(bpe: &BytePairEncoding, min_bytes: usize) -> Vec<u8> {
590+
use rand::{thread_rng, Rng};
591+
let mut result = Vec::new();
592+
while result.len() < min_bytes {
593+
let i = thread_rng().gen_range(0..bpe.num_tokens());
594+
result.extend(bpe.token_bytes(i as u32));
595+
}
596+
result
597+
}

crates/bpe/tests/src/lib.rs

Lines changed: 26 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,14 @@
11
#[cfg(test)]
22
mod tests {
3-
use std::time::Instant;
4-
53
use itertools::Itertools;
64
use rand::{thread_rng, Rng};
7-
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
5+
use tiktoken_rs::cl100k_base_singleton;
86

97
use bpe::appendable_encoder::AppendableEncoder;
10-
use bpe::byte_pair_encoding::{create_test_string, BytePairEncoding};
8+
use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
119
use bpe::interval_encoding::IntervalEncoding;
1210
use bpe::prependable_encoder::PrependableEncoder;
13-
use bpe_openai::{cl100k_base, o200k_base};
11+
use bpe_openai::cl100k_base;
1412

1513
/// This test produces the output for the encoding example in the README.
1614
#[test]
@@ -72,93 +70,64 @@ mod tests {
7270
fn test_appendable_encoder() {
7371
let bpe = &cl100k_base().bpe;
7472
let mut enc = AppendableEncoder::new(bpe);
75-
let input_string = create_test_string(bpe, 100);
76-
for (i, b) in input_string.as_bytes().iter().enumerate() {
73+
let input = create_test_bytes(bpe, 100);
74+
for (i, b) in input.iter().enumerate() {
7775
enc.push(*b);
78-
assert_eq!(
79-
enc.token_count(),
80-
bpe.count(&input_string.as_bytes()[0..i + 1])
81-
);
76+
assert_eq!(enc.token_count(), bpe.count(&input[0..i + 1]));
8277
}
8378
}
8479

8580
#[test]
86-
fn test_correctness_cl100k() {
81+
fn test_correctness() {
8782
// This is quite a challenging test case...
88-
let test_string = std::str::from_utf8(&[
83+
let input = std::str::from_utf8(&[
8984
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
9085
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
9186
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
9287
102, 102, 101, 110, 100,
9388
])
9489
.unwrap();
95-
let time = Instant::now();
9690
let bpe = &cl100k_base().bpe;
97-
println!("{:?}", time.elapsed());
9891
let encoded1 = cl100k_base_singleton()
9992
.lock()
100-
.encode_ordinary(test_string)
101-
.into_iter()
102-
.collect_vec();
103-
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
104-
assert_eq!(encoded1, encoded2);
105-
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
106-
assert_eq!(encoded1, encoded3);
107-
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
108-
assert_eq!(encoded1, encoded4);
109-
}
110-
111-
#[test]
112-
fn test_correctness_o200k() {
113-
// This is quite a challenging test case...
114-
let test_string = std::str::from_utf8(&[
115-
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
116-
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
117-
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
118-
102, 102, 101, 110, 100,
119-
])
120-
.unwrap();
121-
let time = Instant::now();
122-
let bpe = &o200k_base().bpe;
123-
println!("{:?}", time.elapsed());
124-
let encoded1 = o200k_base_singleton()
125-
.lock()
126-
.encode_ordinary(test_string)
93+
.encode_ordinary(input)
12794
.into_iter()
12895
.collect_vec();
129-
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
96+
let encoded2 = bpe.encode_via_backtracking(input.as_bytes());
13097
assert_eq!(encoded1, encoded2);
131-
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
98+
let encoded3 = bpe.encode_via_table(input.as_bytes());
13299
assert_eq!(encoded1, encoded3);
133-
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
100+
let encoded4 = bpe.encode_via_bitfield(input.as_bytes());
134101
assert_eq!(encoded1, encoded4);
135102
}
136103

137104
#[test]
138105
fn test_bpe_equivalence() {
139106
let bpe = &cl100k_base().bpe;
140107
for bytes in [10, 1000, 10000] {
141-
for _ in 0..5 {
142-
let test_input = create_test_string(bpe, bytes);
143-
let encoded1 = bpe.encode_via_backtracking(test_input.as_bytes());
144-
let encoded2 = bpe.encode_via_bitfield(test_input.as_bytes());
108+
for _ in 0..8 {
109+
let input = create_test_bytes(bpe, bytes);
110+
let encoded1 = bpe.encode_via_backtracking(&input);
111+
let encoded2 = bpe.encode_via_bitfield(&input);
145112
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
113+
let encoded3 = bpe.encode_via_table(&input);
114+
assert_eq!(encoded1, encoded3, "{} {}", encoded1.len(), encoded3.len());
146115
}
147116
}
148117
}
149118

150119
#[test]
151120
fn test_interval_count() {
152121
let bpe = &cl100k_base().bpe;
153-
let text = create_test_string(bpe, 10000);
154-
let intervals = IntervalEncoding::new(bpe, text.as_bytes());
122+
let input = create_test_bytes(bpe, 10000);
123+
let intervals = IntervalEncoding::new(bpe, &input);
155124
for _ in 0..1000 {
156-
let start = thread_rng().gen_range(0..text.len());
157-
let end = thread_rng().gen_range(0..text.len());
125+
let start = thread_rng().gen_range(0..input.len());
126+
let end = thread_rng().gen_range(0..input.len());
158127
let range = start.min(end)..start.max(end);
159128
assert_eq!(
160129
intervals.count(range.clone()),
161-
bpe.encode_via_backtracking(&text.as_bytes()[range]).len()
130+
bpe.encode_via_backtracking(&input[range]).len()
162131
);
163132
}
164133
}
@@ -167,10 +136,10 @@ mod tests {
167136
fn test_prependable_encoder() {
168137
let bpe = &cl100k_base().bpe;
169138
let mut enc = PrependableEncoder::new(bpe);
170-
let input_string = create_test_string(bpe, 100);
171-
for (i, b) in input_string.as_bytes().iter().enumerate().rev() {
139+
let input = create_test_bytes(bpe, 100);
140+
for (i, b) in input.iter().enumerate().rev() {
172141
enc.push(*b);
173-
assert_eq!(enc.token_count(), bpe.count(&input_string.as_bytes()[i..]));
142+
assert_eq!(enc.token_count(), bpe.count(&input[i..]));
174143
}
175144
}
176145
}

0 commit comments

Comments
 (0)