diff --git a/tokenizer.py b/tokenizer.py
index c62a0c5..21a87c8 100644
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -21,6 +21,15 @@ def bos_id(self):
     def eos_id(self):
         raise NotImplementedError("This method should be overridden by subclasses.")
 
+    def id_to_piece(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def piece_to_id(self, token_str):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def is_special_token(self, token_id):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
 class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
@@ -38,6 +47,17 @@ def bos_id(self):
     def eos_id(self):
         return self.processor.eos_id()
 
+    def id_to_piece(self, token_id):
+        return self.processor.id_to_piece(token_id).replace("▁", " ")
+
+    def piece_to_id(self, token_str):
+        return self.processor.piece_to_id(token_str.replace(" ", "▁"))
+
+    def is_special_token(self, token_id):
+        return self.processor.IsControl(token_id) \
+            or self.processor.IsUnknown(token_id) \
+            or self.processor.IsUnused(token_id)
+
 class TiktokenWrapper(TokenizerInterface):
     """
     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -53,7 +73,7 @@ def __init__(self, model_path):
         super().__init__(model_path)
         assert os.path.isfile(model_path), str(model_path)
         mergeable_ranks = load_tiktoken_bpe(str(model_path))
-        num_base_tokens = len(mergeable_ranks)
+        self.num_base_tokens = len(mergeable_ranks)
         special_tokens = [
             "<|begin_of_text|>",
             "<|end_of_text|>",
@@ -70,7 +90,7 @@ def __init__(self, model_path):
             for i in range(5, self.num_reserved_special_tokens - 5)
         ]
         self.special_tokens = {
-            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+            token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
         }
         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
@@ -94,6 +114,15 @@ def bos_id(self):
     def eos_id(self):
         return self._eos_id
 
+    def id_to_piece(self, token_id):
+        return self.model.decode([token_id])
+
+    def piece_to_id(self, token_str):
+        return self.model.encode_single_token(token_str)
+
+    def is_special_token(self, token_id):
+        return token_id >= self.num_base_tokens and token_id < self.num_base_tokens + len(self.special_tokens)
+
 def get_tokenizer(tokenizer_model_path, model_name):
     """
     Factory function to get the appropriate tokenizer based on the model name.
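
A minimal usage sketch of the new interface methods (not part of the diff). It assumes the tokenizer file path and model name below are placeholders, that get_tokenizer() accepts (path, model_name) as in the factory shown above, and that the wrapper exposes an encode() method elsewhere in tokenizer.py:

# usage_sketch.py -- illustrative only; paths/model names are assumptions
from tokenizer import get_tokenizer

# Hypothetical checkpoint layout; substitute a real tokenizer.model path.
tok = get_tokenizer("checkpoints/llama-2-7b/tokenizer.model", "llama-2-7b")

ids = tok.encode("Hello world")          # encode() is defined elsewhere in tokenizer.py
for tid in ids:
    piece = tok.id_to_piece(tid)         # surface form; SentencePiece "▁" mapped back to a space
    round_trip = tok.piece_to_id(piece)  # inverse lookup of the piece string
    special = tok.is_special_token(tid)  # control/unknown/unused (SentencePiece) or reserved range (Tiktoken)
    print(tid, repr(piece), round_trip, special)

Note the two backends answer is_special_token differently: the SentencePiece wrapper asks the processor about the token's type, while the Tiktoken wrapper checks whether the id falls in the reserved range appended after num_base_tokens, which is why that count is promoted to an instance attribute in this change.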