Refactor the llama.cpp interface
rlouf committed Nov 29, 2024
1 parent 125f05d commit e4575ed
Showing 2 changed files with 51 additions and 251 deletions.
2 changes: 1 addition & 1 deletion outlines/models/__init__.py
@@ -11,7 +11,7 @@
 from .anthropic import Anthropic
 from .exllamav2 import ExLlamaV2Model, exl2
 from .gemini import Gemini
-from .llamacpp import LlamaCpp, llamacpp
+from .llamacpp import LlamaCpp
 from .mlxlm import MLXLM, mlxlm
 from .openai import AzureOpenAI, OpenAI
 from .transformers import Transformers, TransformerTokenizer, mamba, transformers
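
With this commit the `llamacpp` factory function is no longer re-exported from `outlines.models`; only the `LlamaCpp` class remains. A minimal migration sketch under that assumption (the repository id and GGUF filename below are placeholders, not taken from this commit):

    # Before: module-level factory (removed in this commit)
    from outlines.models import llamacpp
    model = llamacpp("org/some-model-GGUF", "some-model.Q4_K_M.gguf")

    # After: construct the class directly (hypothetical usage of the new interface)
    from outlines.models import LlamaCpp
    model = LlamaCpp.from_pretrained("org/some-model-GGUF", "some-model.Q4_K_M.gguf")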
300 changes: 50 additions & 250 deletions outlines/models/llamacpp.py
@@ -1,21 +1,7 @@
-import dataclasses
 import pickle
 import warnings
-from typing import (
-    TYPE_CHECKING,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypedDict,
-    Union,
-)
-
-from typing_extensions import Unpack
-
-from outlines.generate.api import GenerationParameters, SamplingParameters
+from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union
 
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -107,241 +93,107 @@ def __setstate__(self, state):
raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCppParams(TypedDict, total=False):
suffix: Optional[str]
temperature: float
top_p: float
min_p: float
typical_p: float
seed: int
max_tokens: int
logits_processor: "LogitsProcessorList"
stop: Optional[Union[str, List[str]]]
frequence_penalty: float
presence_penalty: float
repeat_penalty: float
top_k: int
tfs_z: float
mirostat_mode: int
mirostat_tau: float
mirostat_eta: float
stream: bool


 class LlamaCpp:
-    """Represents a model provided by the `llama-cpp-python` library.
-
-    We wrap models from model providing libraries in order to give all of
-    them the same interface in Outlines and allow users to easily switch
-    between providers. This class wraps the `llama_cpp.Llama` class from the
-    `llama-cpp-python` library.
-
-    """
-
-    def __init__(self, model: "Llama"):
-        self.model = model
-
-    @property
-    def tokenizer(self):
-        return LlamaCppTokenizer(self.model)
-
-    def prepare_generation_parameters(
-        self,
-        generation_parameters: GenerationParameters,
-        sampling_parameters: SamplingParameters,
-        structure_logits_processor,
-        **llama_cpp_params: Unpack[LlamaCppParams],
-    ):
-        """Prepare the generation parameters.
-
-        `llama-cpp-python` uses different default values
-
-        """
-        from llama_cpp import LogitsProcessorList
+    """Represents a model provided by the `llama-cpp-python` library."""
+
+    def __init__(self, model_path: Union[str, "Llama"], **kwargs):
+        from llama_cpp import Llama
+
+        if isinstance(model_path, Llama):
+            self.model = model_path
+        else:
+            # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+            if "tokenizer" not in kwargs:
+                warnings.warn(
+                    "The pre-tokenizer in `llama.cpp` handles unicode improperly "
+                    + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n"
+                    + "Outlines may raise a `RuntimeError` when building the regex index.\n"
+                    + "To circumvent this error when using `models.llamacpp()` you may pass the argument"
+                    + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(<hf_repo_id>)`\n"
+                )
+
+            self.model = Llama(model_path, **kwargs)
+
+        self.tokenizer = LlamaCppTokenizer(self.model)
 
-        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
-
-        # We update `llama_cpp_params` with the values the user passed to the
-        # generator.
-        if "stop" not in llama_cpp_params:
-            llama_cpp_params["stop"] = stop_at
-        if "seed" not in llama_cpp_params:
-            llama_cpp_params["seed"] = seed
-
-        # Somehow `llama-cpp-python` generates `max_tokens + 1` tokens
-        if "max_tokens" not in llama_cpp_params:
-            if max_tokens is None:
-                llama_cpp_params["max_tokens"] = -1  # indicates unlimited tokens
-            else:
-                llama_cpp_params["max_tokens"] = max_tokens - 1
-        else:
-            llama_cpp_params["max_tokens"] = llama_cpp_params["max_tokens"] - 1
-
-        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
-            sampling_parameters
-        )
-
-        # We update the `llama_cpp_params` with the sampling values that
-        # were specified by the user via the `Sampler` class, unless they
-        # are also specified in `llama_cpp_params`. We also disable other
-        # sampling methods that are enabled by default and reset the temperature
-        # value.
-        #
-        # See https://github.com/ggerganov/llama.cpp/blob/e11a8999b5690f810c2c99c14347f0834e68c524/common/sampling.h#L22
-        # for the default values in `llama.cpp` and indications to disable the sampling modes.
-        # Mirostat sampling, tail-free sampling and all penalties are disabled by default.
-        #
-        # See https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__
-        # for default values in `llama-cpp-python`
-        if sampler == "beam_search":
-            raise NotImplementedError(
-                "The `llama_cpp_python` library does not support Beam Search."
-            )
-        if num_samples != 1:
-            raise NotImplementedError(
-                "The `llama_cpp_python` library does not allow to take several samples."
-            )
-        if "top_p" not in llama_cpp_params:
-            if top_p is not None:
-                llama_cpp_params["top_p"] = top_p
-            else:
-                llama_cpp_params["top_p"] = 1.0
-
-        if "min_p" not in llama_cpp_params:
-            llama_cpp_params["min_p"] = 0.0
-
-        if "top_k" not in llama_cpp_params:
-            if top_k is not None:
-                llama_cpp_params["top_k"] = top_k
-            else:
-                llama_cpp_params["top_k"] = -1
-
-        if "temperature" not in llama_cpp_params:
-            if temperature is not None:
-                llama_cpp_params["temperature"] = temperature
-            else:
-                llama_cpp_params["temperature"] = 1.0
-
-        if "repeat_penalty" not in llama_cpp_params:
-            llama_cpp_params["repeat_penalty"] = 1.0
-
-        # The choice to stream or not should happen via the high-level API
-        llama_cpp_params["stream"] = False
-
-        if structure_logits_processor is not None:
-            if "logits_processor" in llama_cpp_params:
-                llama_cpp_params["logits_processor"].append(structure_logits_processor)
-            else:
-                llama_cpp_params["logits_processor"] = LogitsProcessorList(
-                    [structure_logits_processor]
-                )
-
-        return llama_cpp_params
+    @classmethod
+    def from_pretrained(cls, repo_id, filename, **kwargs):
+        from llama_cpp import Llama
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        if "tokenizer" not in kwargs:
+            warnings.warn(
+                "The pre-tokenizer in `llama.cpp` handles unicode improperly "
+                + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n"
+                + "Outlines may raise a `RuntimeError` when building the regex index.\n"
+                + "To circumvent this error when using `models.llamacpp()` you may pass the argument"
+                + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(<hf_repo_id>)`\n"
+            )
+
+        model = Llama.from_pretrained(repo_id, filename, **kwargs)
+        return cls(model)
 
-    def generate(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: GenerationParameters,
-        structure_logits_processor,
-        sampling_parameters: SamplingParameters,
-        **llama_cpp_params: Unpack[LlamaCppParams],
-    ) -> str:
+    def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
         """Generate text using `llama-cpp-python`.
 
         Arguments
         ---------
-        prompts
-            A prompt or list of prompts.
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
+        prompt
+            A prompt.
         logits_processor
             The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-        llama_cpp_params
-            Keyword arguments that can be passed to
-            `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params`
-            supersede the values of the parameters in `generation_parameters` and
-            `sampling_parameters`. See the `llama_cpp_python` documentation for
-            a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__
 
         Returns
         -------
         The generated text.
         """
+        from llama_cpp import LogitsProcessorList
+
-        if not isinstance(prompts, str):
+        if not isinstance(prompt, str):
             raise NotImplementedError(
                 "The `llama-cpp-python` library does not support batch inference."
             )
 
-        llama_cpp_params = self.prepare_generation_parameters(
-            generation_parameters,
-            sampling_parameters,
-            structure_logits_processor,
-            **llama_cpp_params,
-        )
-        completion = self.model(prompts, **llama_cpp_params)
+        completion = self.model(
+            prompt,
+            logits_processor=LogitsProcessorList([logits_processor]),
+            **inference_kwargs,
+        )
         result = completion["choices"][0]["text"]
 
         self.model.reset()
 
         return result
 
     def stream(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: GenerationParameters,
-        structure_logits_processor,
-        sampling_parameters: SamplingParameters,
-        **llama_cpp_params: Unpack[LlamaCppParams],
+        self, prompt: str, logits_processor, **inference_kwargs
     ) -> Iterator[str]:
         """Stream text using `llama-cpp-python`.
 
         Arguments
         ---------
-        prompts
-            A prompt or list of prompts.
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
+        prompt
+            A prompt.
         logits_processor
             The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-        llama_cpp_params
-            Keyword arguments that can be passed to
-            `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params`
-            supersede the values of the parameters in `generation_parameters` and
-            `sampling_parameters`. See the `llama_cpp_python` documentation for
-            a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__
 
         Returns
         -------
         A generator that return strings.
         """
 
-        if not isinstance(prompts, str):
+        if not isinstance(prompt, str):
             raise NotImplementedError(
                 "The `llama-cpp-python` library does not support batch inference."
             )
 
-        llama_cpp_params = self.prepare_generation_parameters(
-            generation_parameters,
-            sampling_parameters,
-            structure_logits_processor,
-            **llama_cpp_params,
-        )
-        llama_cpp_params["stream"] = True
-        generator = self.model(prompts, **llama_cpp_params)
+        generator = self.model(
+            prompt,
+            logits_processor=LogitsProcessorList([logits_processor]),
+            **inference_kwargs,
+        )
 
         def token_generator() -> Iterator[str]:
             while True:
@@ -353,55 +205,3 @@ def token_generator() -> Iterator[str]:
                     return
 
         return token_generator()

-    def load_lora(self, adapter_path: str):
-        if self.model._model.apply_lora_from_file(
-            adapter_path,
-            1.0,
-        ):
-            raise RuntimeError(f"Failed to apply LoRA from lora path: {adapter_path}")
-
-
-def llamacpp(
-    repo_id: str, filename: Optional[str] = None, **llamacpp_model_params
-) -> LlamaCpp:
-    """Load a model from the `llama-cpp-python` library.
-
-    We use the `Llama.from_pretrained` classmethod that downloads models
-    directly from the HuggingFace hub, instead of asking users to specify
-    a path to the downloaded model. One can still load a local model
-    by initializing `llama_cpp.Llama` directly.
-
-    Arguments
-    ---------
-    repo_id
-        The name of the model repository.
-    filename:
-        A filename of glob pattern to match the model file in the repo.
-    llama_cpp_model_params
-        Llama-specific model parameters. See the `llama-cpp-python` documentation
-        for the full list: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__
-
-    """
-    from llama_cpp import Llama
-
-    # Default to using the model's full context length
-    if "n_ctx" not in llamacpp_model_params:
-        llamacpp_model_params["n_ctx"] = 0
-
-    if "verbose" not in llamacpp_model_params:
-        llamacpp_model_params["verbose"] = False
-
-    # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
-    if "tokenizer" not in llamacpp_model_params:
-        warnings.warn(
-            "The pre-tokenizer in `llama.cpp` handles unicode improperly "
-            + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n"
-            + "Outlines may raise a `RuntimeError` when building the regex index.\n"
-            + "To circumvent this error when using `models.llamacpp()` you may pass the argument"
-            + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(<hf_repo_id>)`\n"
-        )
-
-    model = Llama.from_pretrained(repo_id, filename, **llamacpp_model_params)
-
-    return LlamaCpp(model)
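
For reference, a minimal sketch of how the refactored class would be used, based only on the signatures added in this diff; the GGUF repository, filename, Hugging Face tokenizer id, and the pass-through logits processor are placeholders rather than part of the commit:

    import llama_cpp
    from outlines.models import LlamaCpp

    def passthrough_logits_processor(input_ids, scores):
        # Stands in for Outlines' structured-generation processor; a
        # llama-cpp-python logits processor maps (input_ids, scores) -> scores.
        return scores

    model = LlamaCpp.from_pretrained(
        "org/some-model-GGUF",            # placeholder repo_id
        "some-model.Q4_K_M.gguf",         # placeholder filename
        # Workaround named in the warning above: pass an HF tokenizer so the
        # regex index can be built despite llama.cpp's pre-tokenizer issue.
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
            "org/some-model"              # placeholder hf_repo_id
        ),
    )

    # `generate` now takes a single prompt, one logits processor, and raw
    # llama-cpp-python keyword arguments that are forwarded to `Llama.__call__`.
    result = model.generate(
        "Write a haiku about mountains.",
        logits_processor=passthrough_logits_processor,
        max_tokens=64,
    )
    print(result)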
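
`stream` has the same calling convention but returns an iterator of text chunks; a short sketch with the same placeholder processor:

    # Chunks are yielded as llama-cpp-python produces them.
    for chunk in model.stream(
        "List three colors.",
        logits_processor=passthrough_logits_processor,
        max_tokens=32,
    ):
        print(chunk, end="")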
