Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow saving and loading multiple "raw" chat template files #36588

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
92 changes: 78 additions & 14 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
from contextlib import contextmanager
from dataclasses import dataclass
from inspect import isfunction
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
from huggingface_hub import list_repo_tree
from huggingface_hub.errors import EntryNotFoundError  # TODO: consider moving this import into a shared hub utility
from packaging import version

from . import __version__
Expand Down Expand Up @@ -146,6 +149,7 @@ class EncodingFast:
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
CHAT_TEMPLATE_FILE = "chat_template.jinja"
CHAT_TEMPLATE_DIR = "additional_chat_templates"

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
Expand Down Expand Up @@ -1981,6 +1985,31 @@ def from_pretrained(
"tokenizer_file": FULL_TOKENIZER_FILE,
"chat_template_file": CHAT_TEMPLATE_FILE,
}

# This block looks for any extra chat template files
if is_local:
template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
if template_dir.is_dir():
for template_file in template_dir.glob("*.jinja"):
template_name = template_file.name.removesuffix(".jinja")
additional_files_names[f"chat_template_{template_name}"] = (
f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
)
else:
try:
for template_file in list_repo_tree(
pretrained_model_name_or_path,
path_in_repo=CHAT_TEMPLATE_DIR,
recursive=False,
revision=revision,
):
if not template_file.path.endswith(".jinja"):
continue
template_name = template_file.path.split("/")[-1].removesuffix(".jinja")
additional_files_names[f"chat_template_{template_name}"] = template_file.path
except EntryNotFoundError:
pass # No template dir means no template files

vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
Expand Down Expand Up @@ -2129,11 +2158,24 @@ def _from_pretrained(
config_tokenizer_class = None
init_kwargs = init_configuration

# If an independent chat template file exists, it takes priority over template entries in the tokenizer config
# If independent chat template file(s) exist, they take priority over template entries in the tokenizer config
chat_templates = {}
chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")]
if chat_template_file is not None:
with open(chat_template_file) as chat_template_handle:
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config
chat_templates["default"] = chat_template_handle.read()
for extra_chat_template in extra_chat_templates:
template_file = resolved_vocab_files.pop(extra_chat_template, None)
if template_file is None:
continue  # Defensive: keys were taken from resolved_vocab_files above, so a missing entry should be unreachable
template_name = extra_chat_template.removeprefix("chat_template_")
with open(template_file) as chat_template_handle:
chat_templates[template_name] = chat_template_handle.read()
if len(chat_templates) == 1 and "default" in chat_templates:
init_kwargs["chat_template"] = chat_templates["default"]
elif chat_templates:
init_kwargs["chat_template"] = chat_templates

if not _is_local:
if "auto_map" in init_kwargs:
Expand Down Expand Up @@ -2430,6 +2472,9 @@ def save_pretrained(
chat_template_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
)
chat_template_dir = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR
)

tokenizer_config = copy.deepcopy(self.init_kwargs)

Expand All @@ -2448,22 +2493,43 @@ def save_pretrained(
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
tokenizer_config.update(self.extra_special_tokens)

saved_raw_chat_template = False
saved_raw_chat_template_files = []
if self.chat_template is not None:
if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
# They will be reconstructed as a single dict during loading.
# We're trying to discourage chat template dicts, and they are always
# saved in the config, never as single files.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
elif kwargs.get("save_raw_chat_template", False):
if kwargs.get("save_raw_chat_template", False) and isinstance(self.chat_template, str):
# New format for single templates is to save them as chat_template.jinja
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
saved_raw_chat_template = True
logger.info(f"chat template saved in {chat_template_file}")
saved_raw_chat_template_files.append(chat_template_file)
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
elif kwargs.get("save_raw_chat_template", False) and isinstance(self.chat_template, dict):
# New format for multiple templates is to save the default as chat_template.jinja
# and the other templates in the additional_chat_templates/ directory
for template_name, template in self.chat_template.items():
if template_name == "default":
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template["default"])
logger.info(f"chat template saved in {chat_template_file}")
saved_raw_chat_template_files.append(chat_template_file)
else:
Path(chat_template_dir).mkdir(exist_ok=True)
template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
with open(template_filepath, "w", encoding="utf-8") as f:
f.write(template)
logger.info(f"chat template saved in {template_filepath}")
saved_raw_chat_template_files.append(template_filepath)
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
elif isinstance(self.chat_template, dict):
# Legacy format for multiple templates:
# chat template dicts are saved to the config as lists of dicts with fixed key names.
# They will be reconstructed as a single dict during loading.
# We're trying to discourage chat template dicts, and they are always
# saved in the config, never as single files.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
else:
# Legacy format for single templates: Just make them a key in tokenizer_config.json
tokenizer_config["chat_template"] = self.chat_template

if len(self.init_inputs) > 0:
Expand Down Expand Up @@ -2518,9 +2584,7 @@ def save_pretrained(
f.write(out_str)
logger.info(f"Special tokens file saved in {special_tokens_map_file}")

file_names = (tokenizer_config_file, special_tokens_map_file)
if saved_raw_chat_template:
file_names += (chat_template_file,)
file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files)

save_files = self._save_pretrained(
save_directory=save_directory,
Expand Down
13 changes: 12 additions & 1 deletion src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,9 @@ def push_to_hub(
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
save_raw_chat_template = deprecated_kwargs.pop(
"save_raw_chat_template", None
) # TODO: This is only used for testing and should be removed once save_raw_chat_template becomes the default
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
Expand Down Expand Up @@ -885,7 +888,15 @@ def push_to_hub(
files_timestamps = self._get_files_timestamps(work_dir)

# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
if save_raw_chat_template:
self.save_pretrained(
work_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
save_raw_chat_template=True,
)
else:
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))
Expand Down
30 changes: 19 additions & 11 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,20 +1626,28 @@ def test_chat_template_dict_saving(self):
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
for save_raw_chat_template in (True, False):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "template1", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
if save_raw_chat_template:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
self.assertNotIn("chat_template", config_dict)
self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja"))
)
else:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "default", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
Expand Down