@@ -191,6 +191,12 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

+        # Cache special tokens for faster access.
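+        # The list/set pairs back the all_special_* properties below and give
+        # O(1) membership checks in the decode/filter paths.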
+        self._special_token_ids = self._get_special_token_ids()
+        self._special_token_ids_set = set(self._special_token_ids)
+        self._special_tokens = self._get_special_tokens(self._special_token_ids)
+        self._special_tokens_set = set(self._special_tokens)
+
         # Vocab sorted by token id.
         self._vocab = self.tokenizer._vocab
         self._max_token_id = self.vocab_size - 1
@@ -210,23 +216,7 @@ def from_pretrained(
             )
         )

-    # the following attributes are set to fit vLLM's design and are used
-    # by the structured output backends.
-    @property
-    def all_special_tokens_extended(self) -> list[str]:
-        return self.all_special_tokens
-
-    @property
-    def all_special_tokens(self) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-
-        return [
-            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
-            for i in self.all_special_ids
-        ]
-
-    @property
-    def all_special_ids(self) -> list[int]:
+    def _get_special_token_ids(self) -> list[int]:
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
         )
@@ -244,6 +234,28 @@ def all_special_ids(self) -> list[int]:
             raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
         return sorted(special_ids)

+    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
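+        # SpecialTokenPolicy.KEEP renders each special id as its literal token
+        # string instead of dropping it from the decoded output.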
+        return [
+            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
+            for i in all_special_ids
+        ]
+
+    # the following attributes are set to fit vLLM's design and are used
+    # by the structured output backends.
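+    # They now return values cached in __init__ instead of re-deriving them on
+    # every access.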
+    @property
+    def all_special_tokens_extended(self) -> list[str]:
+        return self.all_special_tokens
+
+    @property
+    def all_special_tokens(self) -> list[str]:
+        return self._special_tokens
+
+    @property
+    def all_special_ids(self) -> list[int]:
+        return self._special_token_ids
+
     @property
     def bos_token_id(self) -> int:
         return self.tokenizer.bos_id
@@ -277,21 +289,7 @@ def truncation_side(self) -> str:
         raise NotImplementedError()

     def _is_special_token_id(self, token_id: int) -> bool:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            return token_id in self.tokenizer._control_tokens
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            return token_id < self.tokenizer.num_special_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
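+        # O(1) lookup against the id set cached in __init__, replacing the
+        # per-call tokenizer-type dispatch above.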
+        return token_id in self._special_token_ids_set

     def __len__(self) -> int:
         return self.vocab_size
@@ -405,7 +403,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:
         tokens = [
             t
             for t in tokens
-            if (t in to_decode_special_tokens or t not in self.all_special_tokens)
+            if (t in to_decode_special_tokens or t not in self._special_tokens_set)
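+            # Set membership instead of scanning the all_special_tokens list,
+            # which was rebuilt on every call.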
         ]

         if any(isinstance(t, bytes) for t in tokens):
@@ -489,7 +487,7 @@ def convert_ids_to_tokens(
         # We filtered unwanted special tokens so we can decode the rest.
         tokens = [
             self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
-            if token_id not in self.all_special_ids
+            if token_id not in self._special_token_ids_set
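+            # Same cached set as _is_special_token_id; special ids still take
+            # the decode() branch below.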
             else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
             for token_id in ids_kept
         ]