vllm/transformers_utils/tokenizers/mistral.py (merged): 32 additions & 34 deletions
@@ -191,6 +191,12 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
 
+        # Cache special tokens for faster access.
+        self._special_token_ids = self._get_special_token_ids()
+        self._special_token_ids_set = set(self._special_token_ids)
+        self._special_tokens = self._get_special_tokens(self._special_token_ids)
+        self._special_tokens_set = set(self._special_tokens)
+
         # Vocab sorted by token id.
         self._vocab = self.tokenizer._vocab
         self._max_token_id = self.vocab_size - 1
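The hunk above establishes the compute-once pattern for this PR: the special token ids and their string forms are derived a single time in `__init__`, and `set` copies are kept next to the lists so later membership tests cost O(1) on average instead of a linear scan. A minimal self-contained sketch of the pattern, where `_StubTokenizer` and its ids are invented stand-ins rather than the vLLM or mistral-common API:

```python
# Sketch of the compute-once caching pattern; _StubTokenizer and its
# ids are invented for illustration only.
class _StubTokenizer:
    def special_ids(self) -> list[int]:
        return [2, 0, 1]  # e.g. BOS/EOS/PAD, in arbitrary order


class CachedWrapper:
    def __init__(self, tokenizer: _StubTokenizer) -> None:
        # Compute once at construction instead of on every access.
        self._special_token_ids = sorted(tokenizer.special_ids())
        # A set answers "in" in O(1) on average; a list scans in O(n).
        self._special_token_ids_set = set(self._special_token_ids)

    @property
    def all_special_ids(self) -> list[int]:
        return self._special_token_ids  # cached, never recomputed


wrapper = CachedWrapper(_StubTokenizer())
assert wrapper.all_special_ids == [0, 1, 2]
assert 1 in wrapper._special_token_ids_set
```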
@@ -210,23 +216,7 @@ def from_pretrained(
             )
         )
 
-    # the following attributes are set to fit vLLM's design and are used
-    # by the structured output backends.
-    @property
-    def all_special_tokens_extended(self) -> list[str]:
-        return self.all_special_tokens
-
-    @property
-    def all_special_tokens(self) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-
-        return [
-            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
-            for i in self.all_special_ids
-        ]
-
-    @property
-    def all_special_ids(self) -> list[int]:
+    def _get_special_token_ids(self) -> list[int]:
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
         )
@@ -244,6 +234,28 @@ def all_special_ids(self) -> list[int]:
             raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
         return sorted(special_ids)
 
+    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
+        return [
+            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
+            for i in all_special_ids
+        ]
+
+    # the following attributes are set to fit vLLM's design and are used
+    # by the structured output backends.
+    @property
+    def all_special_tokens_extended(self) -> list[str]:
+        return self.all_special_tokens
+
+    @property
+    def all_special_tokens(self) -> list[str]:
+        return self._special_tokens
+
+    @property
+    def all_special_ids(self) -> list[int]:
+        return self._special_token_ids
+
     @property
     def bos_token_id(self) -> int:
         return self.tokenizer.bos_id
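With the caches in place, the `all_special_*` properties become plain accessors; before this change, each access to `all_special_tokens` re-imported `SpecialTokenPolicy` and decoded every special id again. A rough timing sketch of the difference, where `decode` is an invented stub for the real (costlier) tokenizer call:

```python
import timeit

SPECIAL_IDS = list(range(32))


def decode(i: int) -> str:
    # Invented stand-in for tokenizer.decode([i], ...).
    return f"<special_{i}>"


def old_property() -> list[str]:
    # Old behaviour: rebuild the list on every property access.
    return [decode(i) for i in SPECIAL_IDS]


_CACHED = old_property()  # computed once, as in the new __init__


def new_property() -> list[str]:
    # New behaviour: return the precomputed list.
    return _CACHED


print("recompute:", timeit.timeit(old_property, number=10_000))
print("cached:   ", timeit.timeit(new_property, number=10_000))
```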
@@ -277,21 +289,7 @@ def truncation_side(self) -> str:
         raise NotImplementedError()
 
     def _is_special_token_id(self, token_id: int) -> bool:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            return token_id in self.tokenizer._control_tokens
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            return token_id < self.tokenizer.num_special_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return token_id in self._special_token_ids_set
 
     def __len__(self) -> int:
         return self.vocab_size
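`_is_special_token_id` collapses to a single set lookup because `_get_special_token_ids` already encodes both family-specific rules: SentencePiece control-token ids are enumerated explicitly, and the Tekken rule `token_id < num_special_tokens` is equivalent to membership in `range(num_special_tokens)`. A small sketch of that equivalence for the Tekken case, with an invented count:

```python
num_special_tokens = 3  # invented value for illustration


def old_check(token_id: int) -> bool:
    # Old Tekken branch: special ids occupy the low end of the id space.
    return token_id < num_special_tokens


special_ids_set = set(range(num_special_tokens))  # precomputed once


def new_check(token_id: int) -> bool:
    # New unified check: one set lookup, no type dispatch.
    return token_id in special_ids_set


assert all(old_check(i) == new_check(i) for i in range(10))
```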
@@ -405,7 +403,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:
         tokens = [
             t
             for t in tokens
-            if (t in to_decode_special_tokens or t not in self.all_special_tokens)
+            if (t in to_decode_special_tokens or t not in self._special_tokens_set)
         ]
 
         if any(isinstance(t, bytes) for t in tokens):
@@ -489,7 +487,7 @@ def convert_ids_to_tokens(
         # We filtered unwanted special tokens so we can decode the rest.
         tokens = [
             self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
-            if token_id not in self.all_special_ids
+            if token_id not in self._special_token_ids_set
             else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
             for token_id in ids_kept
         ]
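Both call sites above sit inside per-token comprehensions, so each token used to be tested against a list of special tokens (a linear scan per token); testing against the cached sets makes the filter roughly O(tokens) instead of O(tokens x specials). A self-contained timing sketch with arbitrary sizes:

```python
import timeit

special_list = [f"<special_{i}>" for i in range(1_000)]
special_set = set(special_list)
tokens = ["hello"] * 10_000  # no special tokens among them


def filter_with_list() -> list[str]:
    # Old shape: "t not in list" scans all specials for every token.
    return [t for t in tokens if t not in special_list]


def filter_with_set() -> list[str]:
    # New shape: "t not in set" is a constant-time hash lookup.
    return [t for t in tokens if t not in special_set]


print("list:", timeit.timeit(filter_with_list, number=10))
print("set: ", timeit.timeit(filter_with_set, number=10))
```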