Skip to content

Commit

Permalink
fix models.llamacpp vocabulary normalization function
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 authored and rlouf committed Jun 22, 2024
1 parent 60e89f5 commit 8dcd24e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
12 changes: 11 additions & 1 deletion outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def __init__(self, model: "Llama"):
self.tokenizer = model.tokenizer()

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
self._hf_tokenizer = None
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
self._hf_tokenizer = model.tokenizer_.hf_tokenizer
except AttributeError:
# ###
for t in range(model.n_vocab()):
Expand Down Expand Up @@ -71,7 +73,15 @@ def encode(
return token_ids, attention_mask

def convert_token_to_string(self, token: str) -> str:
return token
if self._hf_tokenizer is not None:
from transformers.file_utils import SPIECE_UNDERLINE

token_str = self._hf_tokenizer.convert_tokens_to_string([token])
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
token_str = " " + token_str
return token_str
else:
return token

def __eq__(self, other):
if not isinstance(other, LlamaCppTokenizer):
Expand Down
22 changes: 22 additions & 0 deletions tests/generate/test_integration_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,25 @@ def test_RegexGuide_caching(model, temp_cache_dir):
assert re.fullmatch(regex, structured)
assert re.fullmatch(regex, structured_2)
assert structured != structured_2


def test_tokenizer_vocabulary_decode_sanity():
"""Assert the decoded newline token (198) is the same as the normalized vocab token"""
import llama_cpp

model = models.llamacpp(
"bartowski/Meta-Llama-3-8B-Instruct-GGUF",
"Meta-Llama-3-8B-Instruct-IQ1_M.gguf",
tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
"NousResearch/Hermes-2-Pro-Llama-3-8B"
),
)
tokenizer = generate.regex(model, "a").logits_processor.tokenizer

decoded_nl_token = tokenizer.decode([198])[0]
vocab_nl_token = tokenizer.convert_token_to_string(
[token for token, token_id in tokenizer.vocabulary.items() if token_id == 198][
0
]
)
assert decoded_nl_token == vocab_nl_token

0 comments on commit 8dcd24e

Please sign in to comment.