Refactor the transformers integration #1344

Draft: wants to merge 1 commit into base: v1.0
147 changes: 22 additions & 125 deletions outlines/models/transformers.py
@@ -79,18 +79,6 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):
         self.vocabulary = self.tokenizer.get_vocab()
         self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

-    def encode(
-        self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple["torch.LongTensor", "torch.LongTensor"]:
-        kwargs["padding"] = True
-        kwargs["return_tensors"] = "pt"
-        output = self.tokenizer(prompt, **kwargs)
-        return output["input_ids"], output["attention_mask"]
-
-    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
-        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
-        return text
-
     def convert_token_to_string(self, token: str) -> str:
         from transformers.file_utils import SPIECE_UNDERLINE
@@ -137,62 +125,30 @@ def __init__(
         self.model = model
         self.tokenizer = TransformerTokenizer(tokenizer)

-    def forward(
-        self,
-        input_ids: "torch.LongTensor",
-        attention_mask: "torch.LongTensor",
-        past_key_values: Optional[Tuple] = None,
-    ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
-        """Compute a forward pass through the transformer model.
-
-        Parameters
-        ----------
-        input_ids
-            The input token ids. Must be one or two dimensional.
-        attention_mask
-            The attention mask. Must be one or two dimensional.
-        past_key_values
-            A tuple of tuples containing the cached key and value tensors for each
-            attention head.
-
-        Returns
-        -------
-        The computed logits and the new cached key and value tensors.
-
-        """
-        try:
-            import torch
-        except ImportError:
-            raise ImportError(
-                "The `torch` library needs to be installed to use `transformers` models."
-            )
-        assert 0 < input_ids.ndim < 3
-
-        if past_key_values:
-            input_ids = input_ids[..., -1].unsqueeze(-1)
-
-        with torch.inference_mode():
-            output = self.model(
-                input_ids,
-                attention_mask=attention_mask,
-                return_dict=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                past_key_values=past_key_values,
-            )
-
-        return output.logits, output.past_key_values
-
-    def __call__(
-        self,
-        input_ids: "torch.LongTensor",
-        attention_mask: "torch.LongTensor",
-        past_key_values: Optional[Tuple] = None,
-    ) -> "torch.FloatTensor":
-        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
-        next_token_logits = logits[..., -1, :]
-
-        return next_token_logits, kv_cache
+    def generate(
+        self, prompts: Union[str, List[str]], logits_processor, **inference_kwargs
+    ):
+        from transformers import LogitsProcessorList, GenerationConfig
+
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        input_ids, attention_mask = self.tokenizer.encode(prompts)
+
+        inputs = {
+            "input_ids": input_ids.to(self.model.device),
+            "attention_mask": attention_mask.to(self.model.device),
+        }
+
+        if logits_processor is not None:
+            logits_processor_list = LogitsProcessorList([logits_processor])
+        else:
+            logits_processor_list = None
+
+        generation_config = GenerationConfig(**inference_kwargs)
+
+        output_ids = self.model.generate(
+            **inputs,
+            logits_processor=logits_processor_list,
+            generation_config=generation_config,
+        )

     def generate(
         self,
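
The effect of this hunk: the manual `forward`/`__call__` decoding loop, with Outlines managing the KV cache itself, is replaced by a single call to `transformers`' `model.generate`, with constraints injected through a `LogitsProcessorList`. Below is a self-contained sketch of that pattern using only the public `transformers` API; `DigitsOnlyProcessor` is a made-up toy stand-in for whatever Outlines processor gets passed in, and the `gpt2` checkpoint is illustrative:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

class DigitsOnlyProcessor(LogitsProcessor):
    """Toy constraint: mask every token whose string form is not purely digits."""

    def __init__(self, tokenizer):
        vocab = tokenizer.get_vocab()
        # GPT-2 marks a leading space with "Ġ"; strip it before testing.
        self.allowed = torch.tensor(
            [tid for tok, tid in vocab.items() if tok.lstrip("Ġ").isdigit()]
        )

    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = scores[:, self.allowed]
        return mask

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["The answer is"], padding=True, return_tensors="pt")
output_ids = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([DigitsOnlyProcessor(tokenizer)]),
    max_new_tokens=5,
    do_sample=False,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```

Any token the processor leaves at negative infinity can never be sampled; that masking step is the mechanism Outlines relies on for structured generation, now delegated to `model.generate` instead of a hand-rolled loop.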
@@ -223,27 +179,15 @@ def generate(
             The generated text
         """
         if isinstance(prompts, str):
-            # convert to 2d
-            input_ids, attention_mask = self.tokenizer.encode([prompts])
-        else:
-            input_ids, attention_mask = self.tokenizer.encode(prompts)
+            prompts = [prompts]
+
+        input_ids, attention_mask = self.tokenizer.encode(prompts)

         inputs = {
             "input_ids": input_ids.to(self.model.device),
             "attention_mask": attention_mask.to(self.model.device),
         }
-        if (
-            "attention_mask"
-            not in inspect.signature(self.model.forward).parameters.keys()
-        ):
-            del inputs["attention_mask"]

-        generation_kwargs = self._get_generation_kwargs(
-            prompts,
-            generation_parameters,
-            logits_processor,
-            sampling_parameters,
-        )
-        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)

         # if single str input and single sample per input, convert to a 1D output
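
The trailing comment refers to the squeeze-back convention this hunk preserves: a single string prompt is promoted to a batch of one before encoding, and the result is flattened back to 1D afterwards. A minimal sketch of that round-trip, with a plain Hugging Face tokenizer standing in for the Outlines wrapper:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

prompts = "a single prompt"
single_input = isinstance(prompts, str)
if single_input:
    prompts = [prompts]  # promote to a batch of one

encoded = tokenizer(prompts, padding=True, return_tensors="pt")
input_ids = encoded["input_ids"]      # shape: (1, seq_len)

if single_input:
    input_ids = input_ids.squeeze(0)  # back to 1D for a single prompt
```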
@@ -296,53 +240,6 @@ def stream(
             output_group_ids = generated_ids.select(-1, i).unsqueeze(-1)
             yield self._decode_generation(output_group_ids)

-    def _get_generation_kwargs(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
-        sampling_parameters: SamplingParameters,
-    ) -> dict:
-        """
-        Convert outlines generation parameters into model.generate kwargs
-        """
-        from transformers import GenerationConfig, LogitsProcessorList, set_seed
-
-        max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
-        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
-            sampling_parameters
-        )
-        if max_new_tokens is None:
-            max_new_tokens = int(2**30)
-
-        # global seed, not desirable
-        if seed is not None:
-            set_seed(seed)
-
-        if logits_processor is not None:
-            logits_processor_list = LogitsProcessorList([logits_processor])
-        else:
-            logits_processor_list = None
-
-        generation_config = GenerationConfig(
-            max_new_tokens=max_new_tokens,
-            stop_strings=stop_at,
-            num_return_sequences=(num_samples or 1),
-            top_p=top_p,
-            top_k=top_k,
-            temperature=temperature,
-            do_sample=(sampler == "multinomial"),
-            num_beams=(num_samples if sampler == "beam_search" else 1),
-            eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-        )
-
-        return dict(
-            logits_processor=logits_processor_list,
-            generation_config=generation_config,
-            tokenizer=self.tokenizer.tokenizer,
-        )
-
     def _generate_output_seq(
         self, prompts, inputs, generation_config, **generation_kwargs
     ):
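
The deleted `_get_generation_kwargs` was the translation layer from Outlines' `GenerationParameters`/`SamplingParameters` dataclasses to `model.generate` kwargs. After this refactor the caller supplies those knobs directly; a sketch of the same configuration expressed inline, with illustrative values rather than anything taken from the PR:

```python
from transformers import GenerationConfig

# The same knobs _get_generation_kwargs used to derive from Outlines dataclasses,
# now built directly as a GenerationConfig by the caller.
generation_config = GenerationConfig(
    max_new_tokens=256,
    stop_strings=["\n\n"],   # was `stop_at`
    num_return_sequences=1,  # was `num_samples`
    top_p=0.9,
    top_k=50,
    temperature=0.7,
    do_sample=True,          # was `sampler == "multinomial"`
)
```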