Commit 73444b7

Performance fix MistralTokenizer: cache special ids and tokens (#27925)
Signed-off-by: Julien Denize <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

Parent: 853a8eb

1 file changed

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 32 additions & 34 deletions
@@ -191,6 +191,12 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

+        # Cache special tokens for faster access.
+        self._special_token_ids = self._get_special_token_ids()
+        self._special_token_ids_set = set(self._special_token_ids)
+        self._special_tokens = self._get_special_tokens(self._special_token_ids)
+        self._special_tokens_set = set(self._special_tokens)
+
         # Vocab sorted by token id.
         self._vocab = self.tokenizer._vocab
         self._max_token_id = self.vocab_size - 1
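
This hunk builds the special-token ids and strings once at construction time and keeps set views alongside the lists, so hot-path membership checks no longer recompute anything. A minimal sketch of the compute-once pattern, using a hypothetical ToyTokenizer rather than vLLM's actual class:

# Minimal sketch of the compute-once pattern (hypothetical ToyTokenizer,
# not vLLM code): values that never change after construction are built
# once in __init__; set views make membership tests O(1) on average.
class ToyTokenizer:
    def __init__(self, special_ids, decode):
        self._special_token_ids = sorted(special_ids)
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = [decode(i) for i in self._special_token_ids]
        self._special_tokens_set = set(self._special_tokens)

    def is_special(self, token_id):
        # One hash lookup instead of recomputing the id list per call.
        return token_id in self._special_token_ids_set


toy = ToyTokenizer([2, 0, 1], decode=lambda i: f"<special_{i}>")
assert toy.is_special(1) and not toy.is_special(42)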
@@ -210,23 +216,7 @@ def from_pretrained(
             )
         )

-    # the following attributes are set to fit vLLM's design and are used
-    # by the structured output backends.
-    @property
-    def all_special_tokens_extended(self) -> list[str]:
-        return self.all_special_tokens
-
-    @property
-    def all_special_tokens(self) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-
-        return [
-            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
-            for i in self.all_special_ids
-        ]
-
-    @property
-    def all_special_ids(self) -> list[int]:
+    def _get_special_token_ids(self) -> list[int]:
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
         )
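
This hunk turns the body of the old all_special_ids property into a private helper, _get_special_token_ids; the public properties are reintroduced in the next hunk, backed by the attributes cached in __init__.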
@@ -244,6 +234,28 @@ def all_special_ids(self) -> list[int]:
             raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
         return sorted(special_ids)

+    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
+        return [
+            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
+            for i in all_special_ids
+        ]
+
+    # the following attributes are set to fit vLLM's design and are used
+    # by the structured output backends.
+    @property
+    def all_special_tokens_extended(self) -> list[str]:
+        return self.all_special_tokens
+
+    @property
+    def all_special_tokens(self) -> list[str]:
+        return self._special_tokens
+
+    @property
+    def all_special_ids(self) -> list[int]:
+        return self._special_token_ids
+
     @property
     def bos_token_id(self) -> int:
         return self.tokenizer.bos_id
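
With the helpers in place, the three properties become constant-time accessors over the cached attributes instead of re-decoding every special id on each access. For contrast, a hedged sketch of a lazy alternative (not what this patch does) using functools.cached_property, which pays the cost on first access and memoizes the result:

# Hedged sketch of a lazy alternative (NOT the approach in this patch):
# functools.cached_property computes on first access, then memoizes.
from functools import cached_property


class LazyExample:
    def __init__(self, special_ids):
        self._ids = sorted(special_ids)

    @cached_property
    def all_special_tokens(self):
        # Runs once; later accesses return the stored list.
        return [f"<special_{i}>" for i in self._ids]


lazy = LazyExample([2, 0, 1])
assert lazy.all_special_tokens is lazy.all_special_tokens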
@@ -277,21 +289,7 @@ def truncation_side(self) -> str:
         raise NotImplementedError()

     def _is_special_token_id(self, token_id: int) -> bool:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            return token_id in self.tokenizer._control_tokens
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            return token_id < self.tokenizer.num_special_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return token_id in self._special_token_ids_set

     def __len__(self) -> int:
         return self.vocab_size
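
_is_special_token_id collapses to a single set lookup; the per-call imports, isinstance asserts, and tokenizer-type branching disappear from a hot path. A rough, hypothetical micro-benchmark of the list-versus-set membership difference (absolute numbers vary by machine):

# Hypothetical micro-benchmark: list membership scans O(n) elements,
# set membership is an O(1) average-case hash lookup.
import timeit

special_ids = list(range(1000))      # stand-in for ~1k special ids
special_ids_set = set(special_ids)

list_time = timeit.timeit("999 in special_ids", globals=globals(), number=100_000)
set_time = timeit.timeit("999 in special_ids_set", globals=globals(), number=100_000)
print(f"list: {list_time:.3f}s  set: {set_time:.3f}s")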
@@ -405,7 +403,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:
         tokens = [
             t
             for t in tokens
-            if (t in to_decode_special_tokens or t not in self.all_special_tokens)
+            if (t in to_decode_special_tokens or t not in self._special_tokens_set)
         ]

         if any(isinstance(t, bytes) for t in tokens):
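
Before this change, t not in self.all_special_tokens re-ran the decoding property and then scanned the resulting list for every token, roughly O(n·m) for n tokens and m special tokens; against the cached set the filter is O(n). A toy illustration of the same filter shape, with hypothetical token names:

# Toy version of the filter (hypothetical names): keep a token if it is
# explicitly allowed to be decoded, or if it is not special at all.
special_tokens_set = {"<s>", "</s>", "[INST]"}
to_decode_special_tokens = {"</s>"}

tokens = ["Hello", "<s>", "world", "</s>"]
kept = [
    t
    for t in tokens
    if t in to_decode_special_tokens or t not in special_tokens_set
]
assert kept == ["Hello", "world", "</s>"]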
@@ -489,7 +487,7 @@ def convert_ids_to_tokens(
         # We filtered unwanted special tokens so we can decode the rest.
         tokens = [
             self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
-            if token_id not in self.all_special_ids
+            if token_id not in self._special_token_ids_set
             else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
             for token_id in ids_kept
         ]
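
The same substitution applies in convert_ids_to_tokens: the conditional expression inside the comprehension now tests the cached set instead of rebuilding the id list on every iteration. A small sketch of that branching-comprehension shape, with hypothetical stand-ins for the two decode paths:

# Hypothetical stand-ins for the two decode paths: regular ids map to
# byte pieces, special ids go through the special-token decoder.
special_ids_set = {0, 1}

def piece(token_id):            # stand-in for id_to_byte_piece(..., KEEP)
    return f"piece_{token_id}"

def decode_special(token_id):   # stand-in for decode([token_id], KEEP)
    return f"<special_{token_id}>"

ids_kept = [0, 5, 1, 7]
tokens = [
    piece(i) if i not in special_ids_set else decode_special(i)
    for i in ids_kept
]
assert tokens == ["<special_0>", "piece_5", "<special_1>", "piece_7"]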
