Skip to content

Commit

Permalink
Making TokenizerInterface to be more usable for the user's code.
Browse files Browse the repository at this point in the history
Adding id_to_piece, piece_to_id and is_special_token functionality to TokenizerInterface and the corresponding implementations.
  • Loading branch information
Artyom17 committed Apr 30, 2024
1 parent 30d69b3 commit 3733d15
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import os
import sentencepiece as spm
import tiktoken
Expand All @@ -21,6 +22,15 @@ def bos_id(self):
def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

def id_to_piece(self, token_id):
raise NotImplementedError("This method should be overridden by subclasses.")

def piece_to_id(self, token_str):
raise NotImplementedError("This method should be overridden by subclasses.")

def is_special_token(self, token_id):
raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
Expand All @@ -38,6 +48,17 @@ def bos_id(self):
def eos_id(self):
return self.processor.eos_id()

def id_to_piece(self, token_id):
return self.processor.id_to_piece(token_id).replace("▁", " ")

def piece_to_id(self, token_str):
return self.processor.piece_to_id(token_str.replace(" ", "▁"))

def is_special_token(self, token_id):
return self.processor.IsControl(token_id) \
or self.processor.IsUnknown(token_id) \
or self.processor.IsUnused(token_id)

class TiktokenWrapper(TokenizerInterface):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
Expand All @@ -53,7 +74,7 @@ def __init__(self, model_path):
super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
self.num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
Expand All @@ -70,7 +91,7 @@ def __init__(self, model_path):
for i in range(5, self.num_reserved_special_tokens - 5)
]
self.special_tokens = {
token: num_base_tokens + i for i, token in enumerate(special_tokens)
token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
Expand All @@ -94,6 +115,15 @@ def bos_id(self):
def eos_id(self):
return self._eos_id

def id_to_piece(self, token_id):
return self.model.decode([token_id])

def piece_to_id(self, token_str):
return self.model.encode_single_token(token_str)

def is_special_token(self, token_id):
return token_id >= self.num_base_tokens and token_id < self.num_base_tokens + len(self.special_tokens)

def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.
Expand Down

0 comments on commit 3733d15

Please sign in to comment.