@@ -194,7 +194,7 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
         # Cache special tokens for faster access.
         self._special_token_ids = self._get_special_token_ids()
         self._special_token_ids_set = set(self._special_token_ids)
-        self._special_tokens = self._get_special_tokens()
+        self._special_tokens = self._get_special_tokens(self._special_token_ids)
         self._special_tokens_set = set(self._special_tokens)
 
         # Vocab sorted by token id.
@@ -234,12 +234,12 @@ def _get_special_token_ids(self) -> list[int]:
             raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
         return sorted(special_ids)
 
-    def _get_special_tokens(self) -> list[str]:
+    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
 
         return [
             self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
-            for i in self.all_special_ids
+            for i in all_special_ids
         ]
 
     # the following attributes are set to fit vLLM's design and are used
@@ -289,21 +289,7 @@ def truncation_side(self) -> str:
         raise NotImplementedError()
 
     def _is_special_token_id(self, token_id: int) -> bool:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            return token_id in self.tokenizer._control_tokens
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            return token_id < self.tokenizer.num_special_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return token_id in self._special_token_ids_set
 
     def __len__(self) -> int:
         return self.vocab_size
0 commit comments