Merge pull request #35 from github/count-till-limit
Add count_till_limit method on Tokenizer
hendrikvanantwerpen authored Oct 22, 2024
2 parents 17d5c3e + c28e428 commit df29b9c
Showing 3 changed files with 22 additions and 3 deletions.
22 changes: 21 additions & 1 deletion crates/bpe-openai/src/lib.rs
@@ -80,22 +80,41 @@ impl Tokenizer {
        Ok(Self { bpe, pre })
    }

    /// Count the number of tokens produced when encoding the text. Applies pre-tokenization
    /// before counting.
    pub fn count(&self, text: &str) -> usize {
        self.split(text)
            .map(|piece| self.bpe.count(piece.as_bytes()))
            .sum()
    }

    /// Returns the token count iff the total token count stays below the specified `token_limit`.
    /// Otherwise, it returns `None`. This function can be faster than [`Self::count`] when the
    /// token limit is much smaller than the provided text. Applies pre-tokenization before counting.
    pub fn count_till_limit(&self, text: &str, token_limit: usize) -> Option<usize> {
        // Thread the remaining token budget through the fold; abort early once a piece
        // exhausts it, then convert the leftover budget back into the total token count.
        self.split(text)
            .try_fold(token_limit, |remaining, piece| {
                self.bpe
                    .count_till_limit(piece.as_bytes(), remaining)
                    .map(|piece_count| remaining - piece_count)
            })
            .map(|remaining| token_limit - remaining)
    }

    /// Returns the tokens for the encoding of the given text. Applies pre-tokenization before
    /// encoding.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.split(text)
            .flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
            .collect()
    }

    /// Returns the text corresponding to the given encoding if it is valid UTF-8. Otherwise,
    /// returns `None`.
    pub fn decode(&self, tokens: &[u32]) -> Option<String> {
        String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
    }

    /// Returns an iterator with the text pieces resulting from pre-tokenization. If this
    /// tokenizer does not have pre-tokenization, the iterator returns the full text.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
        match &self.pre {
            Some(pre) => Either::Left(pre.split(text)),
@@ -124,6 +143,7 @@ impl Pretokenizer {
        Ok(Self { pat, lookahead })
    }

    /// Returns an iterator with the text pieces after splitting with the regular expression.
    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
        Splits {
            pat: &self.pat,
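For context, a minimal usage sketch of the two counting methods. This is hypothetical: the `cl100k_base` constructor name is assumed for illustration and does not appear in this diff.

// Hypothetical usage; the `cl100k_base` constructor name is assumed, not shown in this diff.
use bpe_openai::cl100k_base;

fn main() {
    let tok = cl100k_base();
    let text = "The quick brown fox jumps over the lazy dog.";

    // `count` always encodes the full text.
    let total = tok.count(text);

    // `count_till_limit` returns Some(count) only if the total stays below the
    // limit; otherwise it returns None, potentially without encoding everything.
    match tok.count_till_limit(text, 8) {
        Some(n) => println!("fits the budget: {n} tokens"),
        None => println!("over budget (exact total: {total})"),
    }
}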
1 change: 0 additions & 1 deletion crates/bpe/benchmarks/Cargo.toml
@@ -20,7 +20,6 @@ test = true
[dependencies]
bpe = { path = "../../bpe" }
bpe-openai = { path = "../../bpe-openai" }
-bpe-tests = { path = "../tests" }
criterion = "0.5"
rand = "0.8"
tiktoken-rs = "0.6"
2 changes: 1 addition & 1 deletion crates/bpe/src/byte_pair_encoding.rs
@@ -417,7 +417,7 @@ impl BytePairEncoding {
    }

    /// Returns the token count iff the total token count stays below the specified `token_limit`.
-    /// Otherwise, it returns false.
+    /// Otherwise, it returns `None`.
    /// This function can be faster than `count` when the `token_limit` is much smaller than the provided text.
    pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
        let mut enc = BacktrackEncoder::new(self, text);
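The potential speed-up comes from aborting as soon as the running count reaches the limit. A rough, self-contained illustration of that early-exit idea follows; it is not the crate's actual `BacktrackEncoder`-based implementation, just the principle.

// Illustrative only: early-exit counting over an arbitrary token stream.
// The real implementation drives a BacktrackEncoder instead of an iterator.
fn count_till_limit<I: Iterator<Item = u32>>(tokens: I, token_limit: usize) -> Option<usize> {
    let mut count = 0;
    for _tok in tokens {
        count += 1;
        if count >= token_limit {
            // The total would not stay below the limit; stop immediately.
            return None;
        }
    }
    Some(count)
}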
