Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inverse chat templating #33321

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
41 changes: 40 additions & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
requires_backends,
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
from .utils.chat_template_utils import (
_compile_inverse_template,
_compile_jinja_template,
_render_with_assistant_indices,
)
from .utils.import_utils import PROTOBUF_IMPORT_ERROR


Expand Down Expand Up @@ -145,6 +149,7 @@ class EncodingFast:
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
INVERSE_TEMPLATE_FILE = "inverse_template.jinja"

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
Expand Down Expand Up @@ -1631,6 +1636,8 @@ def __init__(self, **kwargs):
# we reconstruct that into a single dict while loading them.
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}

self.inverse_template = kwargs.pop("inverse_template", None)

super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -1916,6 +1923,24 @@ def apply_chat_template(
else:
return rendered

def apply_inverse_template(self, chat: str, inverse_template: Optional[str] = None, skip_json_load: bool = False):
    """
    Render an inverse chat template against a formatted chat string.

    Args:
        chat (`str`):
            The string handed to the template as the `chat` variable.
        inverse_template (`str`, *optional*):
            Jinja template source to use. When `None`, `self.inverse_template` is used instead.
        skip_json_load (`bool`, *optional*, defaults to `False`):
            If `True`, return the rendered template output as-is instead of passing it through `json.loads`.

    Returns:
        The rendered string when `skip_json_load` is `True`, otherwise the object produced by `json.loads` on
        that string.

    Raises:
        ValueError: If no template is passed and `self.inverse_template` is `None`.
    """
    # An explicitly passed template wins; the tokenizer attribute is only consulted as a fallback.
    template_source = self.inverse_template if inverse_template is None else inverse_template
    if template_source is None:
        raise ValueError(
            "Cannot use apply_inverse_template() because tokenizer.inverse_template is not set! Please set "
            "the tokenizer.inverse_template attribute to a valid Jinja template string."
        )

    # The compile helper is cached, so repeated calls with the same template string do not recompile it.
    rendered = _compile_inverse_template(template_source).render(chat=chat)
    return rendered if skip_json_load else json.loads(rendered)

def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
Expand Down Expand Up @@ -2121,6 +2146,7 @@ def from_pretrained(
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
"tokenizer_file": FULL_TOKENIZER_FILE,
"inverse_template": INVERSE_TEMPLATE_FILE,
}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
Expand Down Expand Up @@ -2241,6 +2267,7 @@ def _from_pretrained(
from_slow = kwargs.get("from_slow", False)
gguf_file = kwargs.get("gguf_file", None)
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
inverse_template_file = resolved_vocab_files.pop("inverse_template", None)

# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
# loaded directly from the GGUF file.
Expand Down Expand Up @@ -2342,6 +2369,10 @@ def _from_pretrained(
f" from is '{cls.__name__}'."
)

if inverse_template_file is not None:
with open(inverse_template_file) as chat_template_handle:
init_kwargs["inverse_template"] = chat_template_handle.read()

# Update with newly provided kwargs
init_kwargs.update(kwargs)

Expand Down Expand Up @@ -2577,6 +2608,10 @@ def save_pretrained(
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
)

inverse_chat_template_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + INVERSE_TEMPLATE_FILE
)

tokenizer_config = copy.deepcopy(self.init_kwargs)

# Let's save the init kwargs
Expand All @@ -2599,6 +2634,10 @@ def save_pretrained(
else:
tokenizer_config["chat_template"] = self.chat_template

if self.inverse_template is not None:
with open(inverse_chat_template_file, "w", encoding="utf-8") as f:
f.write(self.inverse_template)

if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
Expand Down
87 changes: 75 additions & 12 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
Expand All @@ -28,7 +29,7 @@
if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.sandbox import ImmutableSandboxedEnvironment, SandboxedEnvironment
else:
jinja2 = None

Expand Down Expand Up @@ -406,21 +407,83 @@ def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

def strftime_now(format):
return datetime.now().strftime(format)

jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
return jinja_env.from_string(chat_template)


@lru_cache
def _compile_inverse_template(inverse_template):
    """
    Compile an inverse chat template string into a Jinja template object.

    The result is memoized with `lru_cache`, so compiling the same template string again returns the cached
    template instead of building a new environment.

    Args:
        inverse_template (`str`):
            Jinja source of the inverse template.

    Returns:
        The template produced by `jinja_env.from_string(inverse_template)`.

    Raises:
        ImportError: If jinja2 is not installed, or the installed version is older than 3.1.0.
    """
    # `jinja2` is set to None at import time when the package is unavailable; fail with a clear message
    # instead of an AttributeError on `None.__version__`.
    if jinja2 is None:
        raise ImportError("apply_inverse_template requires jinja2>=3.1.0 to be installed.")
    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
        raise ImportError(
            "apply_inverse_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
        )

    # NOTE(review): a reviewer asked whether a mutable sandbox is really needed here, since the testing
    # template does not appear to rely on mutation. Switching to ImmutableSandboxedEnvironment should be
    # confirmed against the templates in use before changing it.
    jinja_env = SandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[jinja2.ext.loopcontrols])

    # Helpers exposed to templates: regex search/sorting utilities, JSON (de)serialization and error raising.
    jinja_env.globals["raise_exception"] = raise_exception
    jinja_env.globals["finditer"] = finditer
    jinja_env.globals["sort_by_group_start"] = sort_by_group_start
    jinja_env.filters["tojson"] = tojson
    jinja_env.globals["json_loads"] = json_loads
    # Regex flag constants so templates can pass them as `flags` to `finditer`.
    jinja_env.globals["IGNORECASE"] = re.IGNORECASE
    jinja_env.globals["MULTILINE"] = re.MULTILINE
    jinja_env.globals["DOTALL"] = re.DOTALL
    return jinja_env.from_string(inverse_template)


# Functions for the Jinja environments below this line


@dataclass
class NewMatchObject:
    # Record describing one regex match, as returned by `finditer`.
    # Defined once at module level so it is not rebuilt on every call and so that
    # records produced by different calls compare equal when their fields are equal.

    # group[0] is the whole match, group[i] is capture group i (None if it did not participate)
    group: List[Optional[str]]
    # group_starts[i] is `match.start(i)`; -1 for a group that did not participate
    group_starts: List[int]
    # Optional label attached to the match
    tag: Optional[str]


def finditer(pattern, string, flags=0, add_tag=None, add_tag_from_group=None):
    """
    Run `re.finditer` and return the matches as a list of plain records.

    Args:
        pattern: Regular expression passed to `re.finditer`.
        string: Text to search.
        flags: Regex flags passed to `re.finditer`.
        add_tag: Fixed tag stored on every returned match.
        add_tag_from_group: Index of the group whose text is used as the tag of each match
            (0 is the whole match). Mutually exclusive with `add_tag`.

    Returns:
        A list of `NewMatchObject`, one per match, in the order `re.finditer` yields them.

    Raises:
        jinja2.exceptions.TemplateError: If both `add_tag` and `add_tag_from_group` are given.
    """
    if add_tag is not None and add_tag_from_group is not None:
        raise jinja2.exceptions.TemplateError("Cannot use add_tag and add_tag_from_group at the same time!")

    out = []
    for match in re.finditer(pattern, string, flags=flags):
        # groups() does not include group(0), the whole match, so index from 0 up to
        # the pattern's group count to keep list positions equal to group numbers.
        n_groups = match.re.groups + 1
        groups = [match.group(i) for i in range(n_groups)]
        group_starts = [match.start(i) for i in range(n_groups)]
        # Compute the tag per match without rebinding the `add_tag` parameter.
        tag = add_tag if add_tag_from_group is None else groups[add_tag_from_group]
        out.append(NewMatchObject(group=groups, group_starts=group_starts, tag=tag))
    return out


def sort_by_group_start(matches, group_idx=0, group_idx_by_tag=None):
    """
    Return `matches` sorted by the start offset of one of their groups.

    Args:
        matches: Iterable of objects exposing `tag` and `group_starts` attributes.
        group_idx: Group index whose start offset is the sort key when a match's tag has no
            entry in `group_idx_by_tag`.
        group_idx_by_tag: Optional mapping from tag to the group index to use for matches
            carrying that tag.

    Returns:
        A new list with the matches in ascending order of the selected start offset.
    """
    per_tag_idx = group_idx_by_tag if group_idx_by_tag is not None else {}
    # A tag-specific index takes precedence; otherwise fall back to the global group_idx.
    return sorted(matches, key=lambda m: m.group_starts[per_tag_idx.get(m.tag, group_idx)])


def json_loads(string):
    """Parse `string` with `json.loads` and return the resulting object."""
    parsed = json.loads(string)
    return parsed


def raise_exception(message):
    """Raise a `jinja2.exceptions.TemplateError` carrying `message`."""
    error = jinja2.exceptions.TemplateError(message)
    raise error


def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
    """
    Serialize `x` to a JSON string with `json.dumps`.

    Replaces Jinja's built-in `tojson` filter, which escapes HTML characters, and exposes
    the `ensure_ascii`, `indent`, `separators` and `sort_keys` options of `json.dumps`.
    """
    dump_options = {
        "ensure_ascii": ensure_ascii,
        "indent": indent,
        "separators": separators,
        "sort_keys": sort_keys,
    }
    return json.dumps(x, **dump_options)


def strftime_now(format_str):
    """Format the current local time (`datetime.now()`) using `format_str`."""
    current_time = datetime.now()
    return current_time.strftime(format_str)
Loading
Loading