From faa7c5c77fbe7a0d2f1117d9fc8c407766614294 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Fri, 4 Oct 2024 16:02:47 -0400
Subject: [PATCH] automatically download exl2 model in tests

fix exl bug: sometimes piece_to_id not populated, but get_piece_to_id() still works

enable exl2 in generate.cfg

create OutlinesExLlamaV2Tokenizer rather than monkey patching
---
 outlines/generate/cfg.py        |  9 +-------
 outlines/models/exllamav2.py    | 39 +++++++++++++++++++++++----------
 tests/generate/test_generate.py | 10 ++++++++-
 3 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py
index 034a65ae5..b677040d5 100644
--- a/outlines/generate/cfg.py
+++ b/outlines/generate/cfg.py
@@ -4,7 +4,7 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import ExLlamaV2Model, LlamaCpp, OpenAI, TransformersVision
+from outlines.models import LlamaCpp, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
@@ -41,13 +41,6 @@ def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
 
 
-@cfg.register(ExLlamaV2Model)
-def cfg_exllamav2(model, cfg_str: str, sampler: Sampler = multinomial()):
-    raise NotImplementedError(
-        "Not yet available, track progress in https://github.com/dottxt-ai/outlines/pull/1010"
-    )
-
-
 @cfg.register(LlamaCpp)
 def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()):
     raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer")
diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py
index f06b7e46e..821d4e591 100644
--- a/outlines/models/exllamav2.py
+++ b/outlines/models/exllamav2.py
@@ -1,12 +1,13 @@
 import dataclasses
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union
 
+import torch
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
 
 if TYPE_CHECKING:
-    from exllamav2 import ExLlamaV2Tokenizer
+    import torch.LongTensor
     from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
 
 
@@ -18,13 +19,33 @@ class ExllamaV2Params(TypedDict, total=False):
     max_new_tokens: List[int]
 
 
+class OutlinesExLlamaV2Tokenizer:
+    def __init__(self, tokenizer):
+        self.exl2_tokenizer = tokenizer
+        self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
+        self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
+        self.eos_token_id = self.exl2_tokenizer.eos_token_id
+
+    def convert_token_to_string(self, token):
+        return token
+
+    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
+        decoded = self.exl2_tokenizer.decode(
+            torch.tensor(token_ids),
+            decode_special_tokens=False,
+        )
+        if isinstance(decoded, str):
+            return [decoded]
+        return decoded
+
+
 class ExLlamaV2Model:
     """Represents a `exl2` model."""
 
     def __init__(
         self,
         generator: "ExLlamaV2DynamicGenerator",
-        tokenizer: "ExLlamaV2Tokenizer",
+        tokenizer: "OutlinesExLlamaV2Tokenizer",
         max_seq_len: int,
     ):
         self.generator = generator
@@ -220,14 +241,6 @@ def token_generator() -> Iterator[str]:
         return token_generator()
 
 
-# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b
-def patch_tokenizer(tokenizer):
-    tokenizer.vocabulary = tokenizer.piece_to_id
-    tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)
-    tokenizer.convert_token_to_string = lambda t: t
-    return tokenizer
-
-
 def exl2(
     model_path: str,
     draft_model_path: Optional[str] = None,
@@ -306,7 +319,6 @@ def exl2(
         print("Loading tokenizer...")
     tokenizer = ExLlamaV2Tokenizer(config)
-    tokenizer = patch_tokenizer(tokenizer)
 
     max_batch_size = 4 if paged else 1
 
     draft_model = None
@@ -337,4 +349,7 @@ def exl2(
         paged=paged,
     )
     max_seq_len = cache.max_seq_len
-    return ExLlamaV2Model(generator, tokenizer, max_seq_len)
+
+    outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
+    outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
+    return outlines_exl2_model
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index b36baf9a4..9c288c21e 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -22,8 +22,16 @@ def model_llamacpp(tmp_path_factory):
 
 @pytest.fixture(scope="session")
 def model_exllamav2(tmp_path_factory):
+    from huggingface_hub import snapshot_download
+
+    tmp_dir = tmp_path_factory.mktemp("model_download")
+    model_path = snapshot_download(
+        repo_id="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2",
+        cache_dir=tmp_dir,
+    )
+
     return models.exl2(
-        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        model_path=model_path,
         cache_q4=True,
         paged=False,
     )
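Note (not part of the patch): a minimal usage sketch of the models.exl2 entry point this change touches, assuming an ExLlamaV2-quantized model already exists at a local directory (the path below is a placeholder). It mirrors how the updated test fixture loads a downloaded snapshot.

import outlines

# Hypothetical sketch, not part of the patch. Assumes "/path/to/exl2-model"
# points at a local directory holding an ExLlamaV2-quantized model, e.g. one
# fetched with huggingface_hub.snapshot_download as in the test fixture above.
model = outlines.models.exl2(
    model_path="/path/to/exl2-model",  # placeholder path
    cache_q4=True,   # Q4 cache, as used by the test fixture
    paged=False,     # skip paged attention so it runs on more hardware
)

# models.exl2 now wraps the ExLlamaV2 tokenizer in OutlinesExLlamaV2Tokenizer
# internally, so any Outlines generator can consume the model directly.
generator = outlines.generate.text(model)
print(generator("Write one sentence about tokenizers.", max_tokens=32))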