
feat: chat template impl 1
probably to be rewritten to use a ChatTemplatePromptPipeline
zhudotexe committed Sep 10, 2024
1 parent b751cb9 commit 318e3c0
Showing 1 changed file with 99 additions and 56 deletions.
155 changes: 99 additions & 56 deletions kani/engines/huggingface/base.py
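A minimal usage sketch of what this change enables (not part of the commit; the model id is only an example, and gated checkpoints additionally need token=): constructing a HuggingEngine without a prompt_pipeline should now fall back to the tokenizer's built-in chat template instead of raising NotImplementedError.

from kani import Kani, chat_in_terminal
from kani.engines.huggingface import HuggingEngine

# no prompt_pipeline given: message lengths and prompts come from the chat template
engine = HuggingEngine(model_id="Qwen/Qwen2.5-0.5B-Instruct")
ai = Kani(engine, system_prompt="You are a helpful assistant.")

if __name__ == "__main__":
    chat_in_terminal(ai)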
@@ -1,5 +1,7 @@
import logging
import warnings
from collections import defaultdict
from functools import cached_property
from threading import Thread
from typing import AsyncIterable

@@ -8,6 +10,7 @@
from kani.models import ChatMessage
from kani.prompts.pipeline import PromptPipeline
from ..base import BaseCompletion, BaseEngine, Completion
from ... import ChatRole

try:
import torch
@@ -51,25 +54,28 @@ def __init__(
max_context_size: int = None,
prompt_pipeline: PromptPipeline[str | torch.Tensor] = None,
*,
# hf args
token=None,
device: str | None = None,
tokenizer_kwargs: dict = None,
model_load_kwargs: dict = None,
# kani args
token_reserve: int = 0,
**hyperparams,
):
"""
:param model_id: The ID of the model to load from HuggingFace.
:param max_context_size: The context size of the model. If not given, will be set from the model's config.
:param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat
format (see :class:`.PromptPipeline`).
format (see :class:`.PromptPipeline`). If not passed, uses the Hugging Face chat template if available.
:param token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli.
:param device: The hardware device to use. If not specified, uses CUDA if available; otherwise uses CPU.
:param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``.
:param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``.
:param hyperparams: Additional arguments to supply the model during generation.
:param token_reserve: The number of tokens to reserve for internal engine mechanisms (e.g. if there is a
generation template after the last user message).
generation template after the last user message). If not passed, kani will attempt to infer this from the
prompt pipeline or the model's chat template.
"""
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
@@ -113,14 +119,30 @@ def __init__(
elif self.max_context_size > 1e20:
warnings.warn(
f"The inferred max context size of this model is extremely large ({self.max_context_size}). This"
" may mean that the model has not configured their model_max_len correctly (or you are still using"
" my code in 2050). Please pass the `max_context_size` arg to use the correct model size."
" may mean that the model has not configured their model_max_len correctly. Please pass the"
" `max_context_size` arg to use the correct model size."
)

# infer the token reserve from the pipeline
if self.token_reserve == 0 and self.pipeline:
self.token_reserve = self._infer_token_reserve()

# caches for chat template estimation when no prompt pipeline is given
self._padding_len_by_role: dict[ChatRole, int] = defaultdict(lambda: 0)
if self.token_reserve == 0 and not self.pipeline:
self.token_reserve = self._chat_template_infer_token_reserve()

_chat_template_dummy_msg = {"role": "user", "content": "dummy"}

@cached_property
def _chat_template_dummy_len(self) -> int:
return len(self.tokenizer.apply_chat_template([self._chat_template_dummy_msg], add_generation_prompt=False))

def _chat_template_infer_token_reserve(self):
"""If token_reserve is not set and we have no prompt pipeline, infer it from the chat template."""
full_len = len(self.tokenizer.apply_chat_template([self._chat_template_dummy_msg], add_generation_prompt=True))
return full_len - self._chat_template_dummy_len
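For reference, a standalone sketch of the measurement the two chat-template methods above perform (the checkpoint id is only an example; any tokenizer that ships a chat template works, and gated models additionally need an access token):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dummy = [{"role": "user", "content": "dummy"}]
base_len = len(tok.apply_chat_template(dummy, add_generation_prompt=False))
full_len = len(tok.apply_chat_template(dummy, add_generation_prompt=True))
# the difference is the per-request overhead of the generation prompt (e.g. the assistant header)
print(full_len - base_len)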

def _infer_token_reserve(self):
"""If token_reserve is not set and we have a pipeline, infer it."""
prompt = self.pipeline.execute([], for_measurement=True)
@@ -130,43 +152,55 @@ def _infer_token_reserve(self):
tokenized = self.tokenizer.encode(prompt, add_special_tokens=False)
return len(tokenized)

def _chat_template_message_len(self, message: ChatMessage) -> int:
"""Estimate the length, in tokens, of a single message based on the chat template."""
_ensure_chat_template(self.tokenizer)
conversation = [{"role": message.role.value, "content": message.text}]
try:
return len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))
except TemplateError:
# the template probably enforces strict user/assistant alternation;
# return a best-effort estimate based on the cached padding for messages of this role
raw_tok_len = len(self.tokenizer.encode(message.text, add_special_tokens=False))
return raw_tok_len + self._padding_len_by_role[message.role]
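The TemplateError fallback above can be exercised without any particular model by attaching a deliberately strict template to a small tokenizer. The sketch below is illustrative only: the template string, the cached padding value, and the helper name are made up, but the try/except flow mirrors the method above.

from jinja2 import TemplateError
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
# a deliberately strict template so the fallback branch is exercised
tok.chat_template = (
    "{% for m in messages %}"
    "{% if m['role'] != 'user' and m['role'] != 'assistant' %}"
    "{{ raise_exception('only user/assistant supported') }}"
    "{% endif %}"
    "<|{{ m['role'] }}|>\n{{ m['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)
padding_by_role = {"function": 4}  # stand-in for the engine's cached per-role padding

def estimate_len(role: str, text: str) -> int:
    try:
        return len(tok.apply_chat_template([{"role": role, "content": text}], add_generation_prompt=False))
    except TemplateError:
        # the template rejected this role: count raw content tokens plus cached padding
        return len(tok.encode(text, add_special_tokens=False)) + padding_by_role.get(role, 0)

print(estimate_len("user", "hello"))               # rendered through the template
print(estimate_len("function", '{"temp_f": 72}'))  # falls back to the estimate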

def message_len(self, message: ChatMessage) -> int:
"""Return the length, in tokens, of the given chat message.
The HuggingEngine's default implementation renders the message with ``apply_chat_template`` if no
``prompt_pipeline`` is supplied.
"""
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
)
return self._chat_template_message_len(message)
# raise NotImplementedError(
# "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
# )
prompt = self.pipeline.execute([message], for_measurement=True)
if isinstance(prompt, torch.Tensor):
return len(prompt[0])
# prompt str to tokens
tokenized = self.tokenizer.encode(prompt, add_special_tokens=False)
return len(tokenized)

# def message_len(self, message: ChatMessage) -> int:
# """Return the length, in tokens, of the given chat message.
#
# The HuggingEngine's default implementation renders the message with `apply_chat_template`.
# """
# _ensure_chat_template(self.tokenizer)
# conversation = [{"role": message.role.value, "content": message.text}]
# try:
# return len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))
# except TemplateError:
# # the template probably enforces user/assistant,
# # HACK: let's try a dummy user message then an assistant one, and count the diff
# conversation = [{"role": "user", "content": "a"}]
# dummy_len = len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))
# conversation.append({"role": message.role.value, "content": message.text})
# two_len = len(self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False))
# return two_len - dummy_len
def _chat_template_function_token_reserve(self, functions: list[AIFunction]) -> int:
"""Estimate the function token reserve based on the chat template."""
_ensure_chat_template(self.tokenizer)
tools = [f.json_schema for f in functions]
full_len = len(
self.tokenizer.apply_chat_template(
[self._chat_template_dummy_msg], tools=tools, add_generation_prompt=False
)
)
return full_len - self._chat_template_dummy_len
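The same dummy-message diffing can be checked directly against a tokenizer whose default chat template renders the tools argument (a sketch: the checkpoint id is an example, tools= needs a reasonably recent transformers, and a template that ignores tools will simply report 0):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dummy = [{"role": "user", "content": "dummy"}]
get_weather = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather in a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
base_len = len(tok.apply_chat_template(dummy, add_generation_prompt=False))
tool_len = len(tok.apply_chat_template(dummy, tools=[get_weather], add_generation_prompt=False))
# tokens spent advertising the tool definition to the model
print(tool_len - base_len)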

def function_token_reserve(self, functions: list[AIFunction]) -> int:
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
)
return self._chat_template_function_token_reserve(functions)
# raise NotImplementedError(
# "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
# )
prompt = self.pipeline.execute([], functions, for_measurement=True)
if isinstance(prompt, torch.Tensor):
toklen = len(prompt[0])
@@ -185,6 +219,21 @@ def function_token_reserve(self, functions: list[AIFunction]) -> int:

return toklen

def _chat_template_build_prompt(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None
) -> str | torch.Tensor:
"""Given the list of messages from kani, build either a single string representing the prompt for the model,
or build the token tensor.
The default implementation uses the model tokenizer's `apply_chat_template` method.
"""
_ensure_chat_template(self.tokenizer)
conversation = [{"role": msg.role.value, "content": msg.text} for msg in messages]
tools = [f.json_schema for f in functions] if functions else None
return self.tokenizer.apply_chat_template(
conversation, tools=tools, add_generation_prompt=True, return_tensors="pt"
)
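And a sketch of what the method above hands to the model (same example checkpoint as before): the result is a (1, prompt_len) token tensor that the existing _get_generate_args path can feed straight into generation.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the airspeed velocity of an unladen swallow?"},
]
input_ids = tok.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
print(input_ids.shape)  # torch.Size([1, prompt_len])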

def build_prompt(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None
) -> str | torch.Tensor:
@@ -195,34 +244,15 @@ def build_prompt(
The default behaviour is to call the supplied pipeline.
"""
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
)
prompt = self.pipeline(messages, functions)
prompt = self._chat_template_build_prompt(messages, functions)
# raise NotImplementedError(
# "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
# )
else:
prompt = self.pipeline(messages, functions)
log.debug(f"BUILT PROMPT: {prompt}")
return prompt

# def build_prompt(
# self, messages: list[ChatMessage], functions: list[AIFunction] | None = None
# ) -> str | torch.Tensor:
# """Given the list of messages from kani, build either a single string representing the prompt for the model,
# or build the token tensor.
#
# The default implementation uses the model tokenizer's `apply_chat_template` method.
# """
# _ensure_chat_template(self.tokenizer)
# conversation = [{"role": msg.role.value, "content": msg.text} for msg in messages]
# try:
# return self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
# except TemplateError:
# # the template probably enforces user/assistant,
# # HACK: let's try a dummy user message then the assistant one, and strip the len of the dummy off (pain)
# conv2 = [{"role": "user", "content": "a"}]
# dummy_len = len(self.tokenizer.apply_chat_template(conv2, add_generation_prompt=False))
# conv2.extend(conversation)
# toks = self.tokenizer.apply_chat_template(conv2, add_generation_prompt=True, return_tensors="pt")
# return toks[dummy_len:]

def _get_generate_args(self, prompt: str | torch.Tensor, **hyperparams):
"""Internal method to build common params for the generate call"""
if isinstance(prompt, str):
@@ -273,9 +303,10 @@ async def predict(
# decode to tokens
# the completion shouldn't include the prompt or stop token
content = self.tokenizer.decode(output[0][input_len:], **decode_kwargs).strip()
return Completion(
ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=len(output[0]) - (input_len + 1)
)
# attempt to estimate the assistant message padding if not set
output_len = len(output[0]) - (input_len + 1)
self._chat_template_estimate_padding(content=content, n_tokens_generated=output_len, role=ChatRole.ASSISTANT)
return Completion(ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=output_len)

async def stream(
self,
@@ -326,17 +357,29 @@ def thread_target():

# yield a completion with usage stats
content = "".join(yielded_tokens)
# attempt to estimate the assistant message padding if not set
output_len = len(yielded_tokens)
self._chat_template_estimate_padding(content=content, n_tokens_generated=output_len, role=ChatRole.ASSISTANT)
yield Completion(
message=ChatMessage.assistant(content=content.strip()),
prompt_tokens=input_len,
completion_tokens=len(output_toks[0]) - (input_len + 1),
)

def _chat_template_estimate_padding(self, content: str, n_tokens_generated: int, role: ChatRole):
"""Estimate the number of padding tokens the chat template adds around a message of the given role."""
if self.pipeline or self._padding_len_by_role[role]:
return
log.debug(f"Estimating {role} token padding from chat template...")
reencoded_len = len(self.tokenizer.encode(content, add_special_tokens=False))
self._padding_len_by_role[role] = max(n_tokens_generated - reencoded_len, 0)
log.debug(f"{n_tokens_generated=}, {reencoded_len=}, padding estimate={n_tokens_generated - reencoded_len}")
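A worked instance of the arithmetic, with made-up numbers: if the model emitted 23 tokens for an assistant turn but the bare content re-encodes to 19 tokens, the template wrapped that role's messages in roughly 4 extra tokens, which is the figure the message_len fallback later reuses.

n_tokens_generated = 23  # tokens the model actually produced for the turn
reencoded_len = 19       # the content alone, re-encoded without special tokens
padding_estimate = max(n_tokens_generated - reencoded_len, 0)
print(padding_estimate)  # 4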


def _ensure_chat_template(tokenizer):
if not hasattr(tokenizer, "apply_chat_template"):
raise MissingModelDependencies(
"To use the HuggingEngine with built-in chat templates requires `transformers>=4.34.0`. You currently"
f" have `transformers=={transformers.__version__}`. Please update your transformers with `pip install"
" -U transformers` or use a concrete implementation of the HuggingEngine."
" -U transformers` or supply a `prompt_pipeline` to this HuggingEngine."
)
