Merge pull request #1269 from AI-Hypercomputer:mohit/hf_tokenizer
PiperOrigin-RevId: 728787054
maxtext authors committed Feb 19, 2025
2 parents 858da97 + 3fdd18f commit bea1cef
Showing 7 changed files with 82 additions and 12 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -330,6 +330,7 @@ num_slices: -1
# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "assets/tokenizer.llama2"
tokenizer_type: "sentencepiece"
tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
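
For reference, the new tokenizer_type key defaults to "sentencepiece"; per the build_tokenizer change in MaxText/tokenizer.py further down in this diff, the accepted values are "sentencepiece", "tiktoken", and "huggingface", and anything else raises a ValueError. A minimal sketch of the dispatch follows; the import path and the Hugging Face model id are assumptions, while the asset paths are the ones referenced elsewhere in this commit.

    # Sketch only: import path and the Hugging Face model id are assumptions.
    from MaxText import tokenizer

    sp_tok = tokenizer.build_tokenizer(
        "assets/tokenizer.llama2", "sentencepiece", add_bos=True, add_eos=True, hf_access_token=""
    )
    tt_tok = tokenizer.build_tokenizer(
        "assets/tokenizer_llama3.tiktoken", "tiktoken", add_bos=True, add_eos=True, hf_access_token=""
    )
    hf_tok = tokenizer.build_tokenizer(
        "google/gemma-2-2b", "huggingface", add_bos=False, add_eos=False,
        hf_access_token="hf_...",  # only needed for gated Hugging Face repos
    )
    # tokenizer.build_tokenizer("model.spm", "unknown", True, True, "")  # would raise ValueError
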
4 changes: 2 additions & 2 deletions MaxText/input_pipeline/_input_pipeline_utils.py
@@ -38,9 +38,9 @@ def normalize_features(x, column_name):
return {"inputs": x[column_name], "targets": x[column_name]}


def get_tokenizer(tokenizer_path, add_bos, add_eos):
def get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token=""):
# Load tokenizer
tokenizer_model = tokenizer.build_tokenizer(tokenizer_path, add_bos, add_eos)
tokenizer_model = tokenizer.build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
return tokenizer_model


9 changes: 8 additions & 1 deletion MaxText/input_pipeline/_tfds_data_processing.py
@@ -75,6 +75,7 @@ def get_datasets(
def preprocessing_pipeline(
dataset,
tokenizer_path,
tokenizer_type: str,
global_batch_size: int,
global_mesh,
max_target_length: int,
@@ -91,6 +92,7 @@ def preprocessing_pipeline(
drop_remainder: bool = True,
prefetch_size=tf.data.experimental.AUTOTUNE,
use_dpo: bool = False,
hf_access_token: str = "",
):
"""pipeline for preprocessing TFDS dataset."""
if not use_dpo:
@@ -103,7 +105,7 @@ def preprocessing_pipeline(

data_column_names = data_column_names if use_dpo else ("inputs", "targets")
if tokenize:
tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, add_bos, add_eos)
tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
data_keys = data_column_names
dataset = dataset.map(
lambda x: tokenizer.TokenizeOp(tokenizer=tokenizer_model, features=x, data_keys=data_keys),
@@ -176,6 +178,7 @@ def make_tfds_train_iterator(
train_iter = preprocessing_pipeline(
dataset=train_ds,
tokenizer_path=config.tokenizer_path,
tokenizer_type=config.tokenizer_type,
global_batch_size=config.global_batch_size_to_load,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
@@ -186,6 +189,7 @@ def preprocessing_pipeline(
add_bos=config.add_bos,
add_eos=config.add_eos,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)
return train_iter

@@ -195,6 +199,7 @@ def make_tfds_eval_iterator(
global_mesh,
process_indices_eval,
):
"""load eval dataset, preprocess and return iterators"""
eval_ds = get_datasets(
dataset_name=config.eval_dataset_name,
data_split=config.eval_split,
@@ -207,6 +212,7 @@ def make_tfds_eval_iterator(
eval_iter = preprocessing_pipeline(
dataset=eval_ds,
tokenizer_path=config.tokenizer_path,
tokenizer_type=config.tokenizer_type,
global_batch_size=config.global_batch_size_to_load_eval,
global_mesh=global_mesh,
max_target_length=config.max_target_length,
@@ -217,6 +223,7 @@ def make_tfds_eval_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)

return eval_iter
4 changes: 3 additions & 1 deletion MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -308,7 +308,9 @@ def make_c4_mlperf_train_iterator(
)
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})

sp_tokenizer = get_tokenizer(config.tokenizer_path, config.add_bos, config.add_eos)
sp_tokenizer = get_tokenizer(
config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token
)
train_ds = preprocess_train_dataset(
train_ds,
sp_tokenizer=sp_tokenizer,
3 changes: 2 additions & 1 deletion MaxText/pyconfig.py
@@ -397,7 +397,8 @@ def __init__(self, argv: list[str], **kwargs):

if raw_keys["log_config"]:
for k in keys:
max_logging.log(f"Config param {k}: {raw_keys[k]}")
if k != "hf_access_token":
max_logging.log(f"Config param {k}: {raw_keys[k]}")

@staticmethod
def user_init(raw_keys):
38 changes: 35 additions & 3 deletions MaxText/tests/tokenizer_test.py
@@ -23,6 +23,7 @@
import unittest
import pytest
import tensorflow_datasets as tfds
import subprocess
import os


@@ -38,7 +39,9 @@ def setUpClass(cls):
assets_path = "tests"
vocab_model_name = "test_tokenizer"
cls.tokenizer_path = os.path.join(assets_path, vocab_model_name)
cls.source_tokenizer = _input_pipeline_utils.get_tokenizer("../assets/tokenizer", add_bos=False, add_eos=False)
cls.source_tokenizer = _input_pipeline_utils.get_tokenizer(
"../assets/tokenizer", "sentencepiece", add_bos=False, add_eos=False
)
os.environ["TFDS_DATA_DIR"] = dataset_path
read_config = tfds.ReadConfig(
shuffle_seed=0,
@@ -51,7 +54,9 @@ def setUpClass(cls):
vocab_size=cls.vocab_size,
max_corpus_chars=cls.max_corpus_chars,
)
cls.test_tokenizer = _input_pipeline_utils.get_tokenizer(cls.tokenizer_path, add_bos=False, add_eos=False)
cls.test_tokenizer = _input_pipeline_utils.get_tokenizer(
cls.tokenizer_path, "sentencepiece", add_bos=False, add_eos=False
)

@classmethod
def tearDownClass(cls):
@@ -77,7 +82,10 @@ def setUpClass(cls):
dataset_name = "c4/en:3.0.1"
dataset_path = "gs://maxtext-dataset"
cls.source_tokenizer = _input_pipeline_utils.get_tokenizer(
"../assets/tokenizer_llama3.tiktoken", add_bos=False, add_eos=False
"../assets/tokenizer_llama3.tiktoken",
"tiktoken",
add_bos=False,
add_eos=False,
)
os.environ["TFDS_DATA_DIR"] = dataset_path
read_config = tfds.ReadConfig(
@@ -99,5 +107,29 @@ def test_detokenize(self):
self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), np.asarray(text))


class HFTokenizerTest(unittest.TestCase):
"""Tests for HFTokenizer"""

@classmethod
def setUpClass(cls):
source = "gs://maxtext-gemma/huggingface/gemma2-2b"
destination = "../assets"
subprocess.run(
["gcloud", "storage", "cp", "-R", source, destination],
check=True,
)
cls.hf_tokenizer = _input_pipeline_utils.get_tokenizer(
"../assets/gemma2-2b", "huggingface", add_bos=False, add_eos=False
)
cls.sp_tokenizer = _input_pipeline_utils.get_tokenizer(
"../assets/tokenizer.gemma", "sentencepiece", add_bos=False, add_eos=False
)

@pytest.mark.tpu_only
def test_tokenize(self):
text = "This is a test"
self.assertTrue(np.array_equal(self.hf_tokenizer.encode(text), self.sp_tokenizer.encode(text)))


if __name__ == "__main__":
unittest.main()
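
Outside the test harness, the same Hugging Face path can be exercised directly through get_tokenizer. A minimal sketch, assuming a locally available Hugging Face tokenizer directory (such as the gemma2-2b directory the test above copies from GCS) and an assumed import path:

    # Sketch: assumes a local Hugging Face tokenizer directory is already present.
    from MaxText.input_pipeline import _input_pipeline_utils

    hf_tokenizer = _input_pipeline_utils.get_tokenizer(
        "assets/gemma2-2b",   # directory containing tokenizer.json / tokenizer_config.json
        "huggingface",
        add_bos=False,
        add_eos=False,
        hf_access_token="",   # set only when pulling a gated repo straight from the Hub
    )
    ids = hf_tokenizer.encode("This is a test")
    print(hf_tokenizer.decode(ids))  # round-trips the input text
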
35 changes: 31 additions & 4 deletions MaxText/tokenizer.py
@@ -21,6 +21,7 @@
import tensorflow as tf
import tensorflow_text as tftxt
import max_logging
import transformers
import tiktoken
from tiktoken.load import load_tiktoken_bpe

@@ -200,13 +201,39 @@ def decode(self, t: Sequence[int]) -> str:
return self.sp_tokenizer.detokenize(t)


def build_tokenizer(tokenizer_path, add_bos, add_eos):
class HFTokenizer:
"""
Tokenizing using huggingface tokenizer
"""

def __init__(self, model_path: str, add_bos: bool, add_eos: bool, hf_access_token: str):
max_logging.log(f"Loading HF tokenizer: {model_path}")
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path,
add_bos_token=add_bos,
add_eos_token=add_eos,
token=hf_access_token,
)

def encode(self, s: str) -> List[int]:
return self.tokenizer.encode(s)

def decode(self, t: Sequence[int]) -> str:
return self.tokenizer.decode(t)


def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token):
"""Loads the tokenizer at `tokenizer_path`"""
max_logging.log(f"Tokenizer path: {tokenizer_path}")
if "tiktoken" in tokenizer_path:
if tokenizer_type == "tiktoken":
assert "tiktoken" in tokenizer_path, f"Invalid tokenizer type: {tokenizer_type} chosen for {tokenizer_path}"
return TikTokenTokenizer(tokenizer_path, add_bos, add_eos)
else:
elif tokenizer_type == "huggingface":
return HFTokenizer(tokenizer_path, add_bos, add_eos, hf_access_token)
elif tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)
else:
raise ValueError(f"Invalid tokenizer_type:{tokenizer_type} chosen in config")


def TokenizeOp(tokenizer, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
Expand All @@ -220,7 +247,7 @@ def _process_string(string_tensor):
return [modified_string]

for k in data_keys:
if isinstance(tokenizer, TikTokenTokenizer):
if isinstance(tokenizer, (TikTokenTokenizer, HFTokenizer)):
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
elif isinstance(tokenizer, SentencePieceTokenizer):
features[k] = tokenizer.encode(features[k])
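
Since HFTokenizer is now routed through the same tf.py_function branch as TikTokenTokenizer in TokenizeOp, it can be mapped over a tf.data pipeline of raw strings, which is what the TFDS input pipeline above does. A rough sketch under the same assumptions as the earlier snippets (placeholder tokenizer path, assumed import path):

    # Sketch: tokenize a toy tf.data dataset with an HF-backed tokenizer via TokenizeOp.
    import tensorflow as tf
    from MaxText import tokenizer

    hf_tok = tokenizer.build_tokenizer("assets/gemma2-2b", "huggingface", False, False, "")
    ds = tf.data.Dataset.from_tensor_slices(
        {"inputs": ["This is a test"], "targets": ["This is a test"]}
    )
    ds = ds.map(
        lambda x: tokenizer.TokenizeOp(tokenizer=hf_tok, features=x, data_keys=("inputs", "targets")),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
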
