@@ -190,8 +190,12 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
         }
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
+
+        # Cache special tokens for faster access.
         self._special_token_ids = self._get_special_token_ids()
+        self._special_token_ids_set = set(self._special_token_ids)
         self._special_tokens = self._get_special_tokens()
+        self._special_tokens_set = set(self._special_tokens)
 
         # Vocab sorted by token id.
         self._vocab = self.tokenizer._vocab
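The cached sets pay off because `in` on a Python list is a linear scan, while `in` on a set is an average O(1) hash lookup. A minimal, self-contained sketch of the gap (the token strings and counts below are made up for illustration, not taken from the tokenizer):

import timeit

special_tokens = [f"<special_{i}>" for i in range(1_000)]
special_tokens_set = set(special_tokens)

# Probe with an absent token: the list is scanned end to end (its worst
# case), while the set answers with a single hash lookup.
list_time = timeit.timeit(lambda: "hello" in special_tokens, number=100_000)
set_time = timeit.timeit(lambda: "hello" in special_tokens_set, number=100_000)
print(f"list: {list_time:.3f}s  set: {set_time:.3f}s")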
@@ -413,7 +417,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:
         tokens = [
             t
             for t in tokens
-            if (t in to_decode_special_tokens or t not in self.all_special_tokens)
+            if (t in to_decode_special_tokens or t not in self._special_tokens_set)
         ]
 
         if any(isinstance(t, bytes) for t in tokens):
@@ -497,7 +501,7 @@ def convert_ids_to_tokens(
         # We filtered unwanted special tokens so we can decode the rest.
         tokens = [
             self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
-            if token_id not in self.all_special_ids
+            if token_id not in self._special_token_ids_set
             else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
             for token_id in ids_kept
         ]
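Both rewritten comprehensions run one membership test per token, so for m tokens and n special tokens the list-backed lookups cost O(m * n) per call while the cached sets bring it down to O(m); if `all_special_tokens` / `all_special_ids` are computed properties (an assumption about the surrounding class), the change also avoids rebuilding them on every access inside the loop. A runnable sketch of the filtering pattern, using hypothetical names rather than the class's actual API:

def filter_special_ids(ids: list[int], special_token_ids_set: set[int]) -> list[int]:
    # One O(1) set lookup per id; a list here would rescan all special
    # tokens for every id. Names are illustrative only.
    return [i for i in ids if i not in special_token_ids_set]

# Made-up ids: 0 and 2 play the role of special tokens.
print(filter_special_ids([0, 5, 2, 7], {0, 1, 2}))  # -> [5, 7]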