From 13c39ac657e4e5123739fb4549869c8ff8ee0bf7 Mon Sep 17 00:00:00 2001
From: Sanjeev Satheesh
Date: Thu, 10 Oct 2024 18:14:15 -0700
Subject: [PATCH] ADLR/megatron-lm!1909 - Standard interface for getting offsets from tokenizers

---
 Dockerfile.ci                                  |   1 +
 Dockerfile.ci.dev                              |   1 +
 megatron/core/datasets/megatron_tokenizer.py   |  15 ++
 .../inference/text_generation/tokenization.py  |  36 ++--
 megatron/training/tokenizer/tokenizer.py       |  25 +++
 tests/unit_tests/test_tokenizer.py             | 193 ++++++++++++++++++
 6 files changed, 251 insertions(+), 20 deletions(-)
 create mode 100644 tests/unit_tests/test_tokenizer.py

diff --git a/Dockerfile.ci b/Dockerfile.ci
index fa13c48fd4..f1b693b9d9 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -36,6 +36,7 @@ pytest-cov \
 pytest_mock \
 pytest-random-order \
 sentencepiece \
+tiktoken \
 wrapt \
 zarr \
 wandb \
diff --git a/Dockerfile.ci.dev b/Dockerfile.ci.dev
index fa13c48fd4..f1b693b9d9 100644
--- a/Dockerfile.ci.dev
+++ b/Dockerfile.ci.dev
@@ -36,6 +36,7 @@ pytest-cov \
 pytest_mock \
 pytest-random-order \
 sentencepiece \
+tiktoken \
 wrapt \
 zarr \
 wandb \
diff --git a/megatron/core/datasets/megatron_tokenizer.py b/megatron/core/datasets/megatron_tokenizer.py
index 8adeff418b..84f3546cf3 100644
--- a/megatron/core/datasets/megatron_tokenizer.py
+++ b/megatron/core/datasets/megatron_tokenizer.py
@@ -57,6 +57,21 @@ def detokenize(self, ids: numpy.ndarray) -> str:
         """
         raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__))
 
+    def offsets(self, ids: list[int], text: str) -> list[int]:
+        """Convert embedding ids to text offsets
+
+        Args:
+            ids (list[int]): The ids to convert
+            text (str): The text to convert
+
+        Returns:
+            list[int]: The converted offsets
+
+        Raises:
+            NotImplementedError: Non-abstract, optional method
+        """
+        raise NotImplementedError("{} has no method 'offsets'".format(type(self).__name__))
+
     @property
     @abstractmethod
     def vocab(self):
diff --git a/megatron/inference/text_generation/tokenization.py b/megatron/inference/text_generation/tokenization.py
index 36bec4d50e..e58e991305 100644
--- a/megatron/inference/text_generation/tokenization.py
+++ b/megatron/inference/text_generation/tokenization.py
@@ -24,28 +24,24 @@ def detokenize_generations(tokens_gpu_tensor,
     lengths = lengths_gpu_tensor.cpu().numpy().tolist()
     for sequence_tokens, length in zip(tokens, lengths):
         sequence_tokens = sequence_tokens[:length]
-        prompts_plus_generations.append(
-            tokenizer.detokenize(sequence_tokens))
+        detok_str = tokenizer.detokenize(sequence_tokens)
+        prompts_plus_generations.append(detok_str)
         if detokenize_segments:
-            words = []
-            for token in sequence_tokens:
-                if args.tokenizer_type in ['SentencePieceTokenizer',
-                                           'GPTSentencePieceTokenizer',
-                                           'HuggingFaceTokenizer',
-                                           'Llama2Tokenizer']:
-                    word = tokenizer.decoder[token]
-                elif args.tokenizer_type == 'TikTokenizer':
-                    word = tokenizer.detokenize([token])
-                elif args.tokenizer_type in ['Llama3Tokenizer', 'MistralTokenizer']:
-                    word = tokenizer.decode([token])
-                elif args.tokenizer_type == 'NullTokenizer':
-                    word = str(token)
-                else:
+            try:
+                offsets = tokenizer.offsets(sequence_tokens, detok_str)
+                words = [
+                    detok_str[start:end]
+                    for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
+                ]
+            except NotImplementedError:
+                words = []
+                for token in sequence_tokens:
                     word = tokenizer.tokenizer.decoder[token]
-                    word = bytearray(
-                        [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
-                            'utf-8', errors='replace')
-                words.append(word)
+                    word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
+                        "utf-8", errors="replace"
+                    )
+                    words.append(word)
+
             prompts_plus_generations_segments.append(words)
 
     return tokens, prompts_plus_generations, prompts_plus_generations_segments
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 226ae1e799..aa5e410076 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -128,6 +128,18 @@ def tokenize(self, text, **kwargs):
     def detokenize(self, token_ids, **kwargs):
         return self._tokenizer.decode(token_ids, **kwargs)
 
+    def offsets(self, ids: list[int], text: str) -> list[int]:
+        retok_ids: "transformers.BatchEncoding" = self._tokenizer(text)
+        offsets, next_start_idx = [], 0
+        for i in range(len(ids)):
+            span = retok_ids.token_to_chars(i)
+            if span is not None:
+                offsets.append(span.start)
+                next_start_idx = span.end
+            else:
+                offsets.append(next_start_idx)
+        return offsets
+
     @property
     def eod(self):
         return self._tokenizer.eos_token_id
@@ -426,6 +438,9 @@ def detokenize(self, ids):
             text += self.tokenizer.decode_ids(ids[last_i:])
         return text
 
+    def offsets(self, ids: list[int], text: str) -> list[int]:
+        return [p.begin for p in self.tokenizer.decode_ids_as_immutable_proto(ids).pieces]
+
     @property
     def cls(self):
         return self._cls_id
@@ -687,6 +702,9 @@ def tokenize(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
     def detokenize(self, tokens: List[int]) -> str:
         return self._model.decode(tokens)
 
+    def offsets(self, ids: list[int], text: str) -> list[int]:
+        return self._model.decode_with_offsets(ids)[1]
+
     @property
     def vocab_size(self) -> int:
         return self._vocab_size
@@ -713,6 +731,13 @@ def detokenize(self, ids):
         text = [str(x) for x in ids]
         return ' '.join(text)
 
+    def offsets(self, ids: list[int], text: str) -> list[int]:
+        offsets, start_idx = [], 0
+        for id_ in ids:
+            offsets.append(start_idx)
+            start_idx += 1 + len(str(id_))
+        return offsets
+
     @property
     def vocab_size(self):
         return self._vocab_size_without_eod + 1
diff --git a/tests/unit_tests/test_tokenizer.py b/tests/unit_tests/test_tokenizer.py
new file mode 100644
index 0000000000..13e222953b
--- /dev/null
+++ b/tests/unit_tests/test_tokenizer.py
@@ -0,0 +1,193 @@
+import base64
+import json
+from argparse import Namespace
+from pathlib import Path
+
+import pytest
+import requests
+
+from megatron.training import tokenizer
+from megatron.training.tokenizer.gpt2_tokenization import PRETRAINED_VOCAB_ARCHIVE_MAP
+
+TOKENIZER_DIR = Path("~/data/tokenizers").expanduser()
+
+# Copied over from test_preprocess_data.py
+__LOCAL_GPT2_VOCAB = "/home/gitlab-runner/data/gpt3_data/gpt2-vocab.json"
+
+
+def offsets_to_substrs(offsets, string):
+    return [string[start:end] for start, end in zip([0] + offsets, offsets + [len(string)])]
+
+
+def local_test_specs():
+    return [
+        Namespace(
+            rank=0,
+            tensor_model_parallel_size=8,
+            make_vocab_size_divisible_by=128,
+            tokenizer_type="GPTSentencePieceTokenizer",
+            tokenizer_model=f"{TOKENIZER_DIR}/nemotron_2_256k.model",
+        ),
+        Namespace(
+            rank=0,
+            vocab_size=131072,
+            make_vocab_size_divisible_by=128,
+            tensor_model_parallel_size=8,
+            tokenizer_type="TikTokenizer",
+            tokenizer_model=f"{TOKENIZER_DIR}/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json",
+            tiktoken_pattern="v2",
+            tiktoken_num_special_tokens=1000,
+            tiktoken_special_tokens=["<unk>", "<s>", "</s>"],
+        ),
+        Namespace(
+            rank=0,
+            vocab_size=131072,
+            make_vocab_size_divisible_by=128,
+            tensor_model_parallel_size=8,
tokenizer_type="TikTokenizer", + tokenizer_model=f"{TOKENIZER_DIR}/multiMixV5_fix_default_500000_128k.vocab.json", + tiktoken_pattern="v1", + tiktoken_num_special_tokens=1000, + tiktoken_special_tokens=["", "", ""], + ), + Namespace( + rank=0, + vocab_size=128000, + make_vocab_size_divisible_by=128, + tensor_model_parallel_size=8, + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model="meta-llama/Llama-2-7b-hf", + ), + Namespace( + rank=0, + vocab_size=128000, + make_vocab_size_divisible_by=128, + tensor_model_parallel_size=8, + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model="meta-llama/Meta-Llama-3.1-8B", + ), + ] + + +@pytest.fixture(scope="session") +def gpt2_tiktok_vocab(tmp_path_factory): + + if Path(__LOCAL_GPT2_VOCAB).exists(): + with open(__LOCAL_GPT2_VOCAB, "r", encoding="utf-8") as reader: + gpt2_vocab = json.load(reader) + else: + gpt2_vocab = json.loads(requests.get(PRETRAINED_VOCAB_ARCHIVE_MAP["gpt2"]).content) + + N = 256 + tiktok_vocab = [ + {"token_bytes": base64.b64encode(bytes([i])).decode("utf-8"), "token_str": str(i)} + for i in range(N) + ] + tiktok_vocab_bytes = {x["token_bytes"] for x in tiktok_vocab} + + tiktok_vocab += [ + {"token_bytes": base64.b64encode(token.encode('utf-8')).decode("utf-8"), "token_str": token} + for token in gpt2_vocab + if base64.b64encode(token.encode('utf-8')).decode("utf-8") not in tiktok_vocab_bytes + ] + + for i, entry in enumerate(tiktok_vocab): + entry["rank"] = i + + for i, x in enumerate(tiktok_vocab): + assert x.keys() == {"rank", "token_bytes", "token_str"} + assert x["rank"] == i + merge = base64.b64decode(x["token_bytes"]) + assert i >= 256 or merge == bytes([i]), f"{i} {merge} {bytes([i])}" + + file_name = tmp_path_factory.mktemp("data") / "gpt2_vocab.json" + with open(file_name, "w") as f: + json.dump(tiktok_vocab, f) + + return Namespace( + rank=0, + vocab_size=32768, + make_vocab_size_divisible_by=128, + tensor_model_parallel_size=8, + tokenizer_type="TikTokenizer", + tokenizer_model=str(file_name), + tiktoken_pattern="v1", + tiktoken_num_special_tokens=1000, + tiktoken_special_tokens=["", "", ""], + ) + + +def specs(): + if TOKENIZER_DIR.exists(): + return local_test_specs() + return [] + + +@pytest.mark.parametrize("args", specs()) +def test_tokenizer(args): + tok = tokenizer.build_tokenizer(args) + run_tokenizer_tests(tok) + + +def test_gpt2_tiktok_tokenizer(gpt2_tiktok_vocab): + tok = tokenizer.build_tokenizer(gpt2_tiktok_vocab) + run_tokenizer_tests(tok) + + +def run_tokenizer_tests(tok): + string1 = ( + "The following are multiple choice questions (with answers) about college biology.\n" + "Monoclonal antisera are distinguished from polyclonal antisera in which of the " + "following ways?\n" + "A. Each type of antibody in a monoclonal antiserum reacts against a single region of " + "a single antigen; each type of antibody in a polyclonal antiserum reacts against " + "multiple regions of different antigens.\n" + "B. A monoclonal antibody reacts against multiple regions of a single antigen; a " + "polyclonal antibody reacts against a single region of related antigens.\n" + "C. A monoclonal antiserum contains antibodies secreted from the descendants of a " + "single B lymphocyte; a polyclonal antiserum contains antibodies secreted from the " + "descendants of different B lymphocytes.\n" + "D. 
+        "single B lymphocyte; a polyclonal antiserum contains antibodies secreted from the "
+        "descendants of both B and T lymphocytes.\n"
+        "Answer: C"
+    )
+    string2 = "Жизнь прекрасна и удивительна"
+    string3 = "お誕生日おめでとう"
+    strings = [string1, string2, string3]
+
+    for test_string in strings:
+        toks = tok.tokenize(test_string)
+        offsets = tok.offsets(toks, test_string)
+        dec = offsets_to_substrs(offsets, test_string)
+        detok_str = ''.join(dec)
+        # the following is not necessarily true by construction above,
+        # since many tokenizers may operate at the byte level and not
+        # only at the character level.
+        assert (
+            detok_str == test_string
+        ), f"Detokenized string {detok_str} does not match original {test_string}"
+        assert len(toks) == len(
+            offsets
+        ), f"Tokenized string {toks} does not match original {offsets}"
+
+
+def test_null_tokenizer():
+    args = Namespace(
+        tokenizer_type="NullTokenizer",
+        rank=0,
+        vocab_size=128000,
+        make_vocab_size_divisible_by=128,
+        tensor_model_parallel_size=8,
+    )
+    tok = tokenizer.build_tokenizer(args)
+    test_string = "1 23 456 789"
+    toks = tok.tokenize(test_string)
+    offsets = tok.offsets(toks, test_string)
+    dec = offsets_to_substrs(offsets, test_string)
+    detok_str = ''.join(dec)
+
+    assert (
+        detok_str == test_string
+    ), f"Detokenized string {detok_str} does not match original {test_string}"
+    assert len(toks) == len(offsets), f"Tokenized string {toks} does not match original {offsets}"
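
Usage sketch (not part of the diff above): the snippet below shows how a caller can consume the
new offsets() interface, mirroring the try/except fallback this patch adds to
detokenize_generations(). It is a minimal sketch, assuming the HuggingFaceTokenizer Namespace
fields borrowed from the new unit test; the meta-llama/Meta-Llama-3.1-8B model id is an
illustrative placeholder, and any tokenizer returned by build_tokenizer() behaves the same way.

from argparse import Namespace

from megatron.training import tokenizer as tokenizer_lib

# Build a tokenizer the same way the new unit test does (fields copied from the test).
args = Namespace(
    rank=0,
    vocab_size=128000,
    make_vocab_size_divisible_by=128,
    tensor_model_parallel_size=8,
    tokenizer_type="HuggingFaceTokenizer",
    tokenizer_model="meta-llama/Meta-Llama-3.1-8B",  # placeholder model id
)
tok = tokenizer_lib.build_tokenizer(args)

text = "Standard interface for getting offsets from tokenizers"
ids = tok.tokenize(text)

try:
    # offsets[i] is the character position in `text` where token i starts, so adjacent
    # offsets delimit the per-token segments of the detokenized string.
    offsets = tok.offsets(ids, text)
    segments = [text[start:end] for start, end in zip(offsets, offsets[1:] + [len(text)])]
except NotImplementedError:
    # Tokenizers without an offsets() implementation keep working: fall back to decoding
    # one token at a time, as detokenize_generations() does in this patch.
    segments = [tok.detokenize([i]) for i in ids]

print(segments)  # byte-level tokenizers may not round-trip every segment exactly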