From bea27c5d6ff50a30476866c90774f80bf35a0b3f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 16 Apr 2024 01:28:49 -0700 Subject: [PATCH 1/7] Add a replace operation --- CMakeLists.txt | 3 +- mlx/data/Dataset.cpp | 26 +++++++++++++ mlx/data/Dataset.h | 12 ++++++ mlx/data/core/Utils.cpp | 73 +++++++++++++++++++++++++++++++++++- mlx/data/core/Utils.h | 8 +++- mlx/data/op/Replace.cpp | 30 +++++++++++++++ mlx/data/op/Replace.h | 30 +++++++++++++++ python/src/wrap_dataset.h | 36 ++++++++++++++++++ python/tests/test_buffer.py | 8 +++- python/tests/test_replace.py | 24 ++++++++++++ 10 files changed, 244 insertions(+), 6 deletions(-) create mode 100644 mlx/data/op/Replace.cpp create mode 100644 mlx/data/op/Replace.h create mode 100644 python/tests/test_replace.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 02da93b..8bd5078 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,7 +204,8 @@ set(mlxdata-src ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Squeeze.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Tokenize.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/ImageTransform.cpp - ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp) + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Replace.cpp) if(AWSSDK_FOUND) list(APPEND mlxdata-src diff --git a/mlx/data/Dataset.cpp b/mlx/data/Dataset.cpp index acd1bfb..6411df9 100644 --- a/mlx/data/Dataset.cpp +++ b/mlx/data/Dataset.cpp @@ -21,6 +21,7 @@ #include "mlx/data/op/ReadFromTAR.h" #include "mlx/data/op/RemoveValue.h" #include "mlx/data/op/RenameKey.h" +#include "mlx/data/op/Replace.h" #include "mlx/data/op/SampleTransform.h" #include "mlx/data/op/SaveImage.h" #include "mlx/data/op/Shape.h" @@ -633,6 +634,31 @@ T Dataset::remove_value_if( } } +template +T Dataset::replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) { + return transform_( + std::make_shared(key, old, replacement, count)); +} + +template +T 
Dataset::replace_if( + bool cond, + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) { + if (cond) { + return transform_( + std::make_shared(key, old, replacement, count)); + } else { + return T(self_); + } +} + template T Dataset::rename_key(const std::string& ikey, const std::string& okey) const { diff --git a/mlx/data/Dataset.h b/mlx/data/Dataset.h index ce387ae..72e3947 100644 --- a/mlx/data/Dataset.h +++ b/mlx/data/Dataset.h @@ -314,6 +314,18 @@ class Dataset { double value, double pad) const; + T replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count = -1); + T replace_if( + bool cond, + const std::string& key, + const std::string& old, + const std::string& replacement, + int count = -1); + T rename_key(const std::string& ikey, const std::string& okey) const; T rename_key_if(bool cond, const std::string& ikey, const std::string& okey) const; diff --git a/mlx/data/core/Utils.cpp b/mlx/data/core/Utils.cpp index daadb26..30f7b8b 100644 --- a/mlx/data/core/Utils.cpp +++ b/mlx/data/core/Utils.cpp @@ -1,7 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include "mlx/data/core/Utils.h" -#include namespace { @@ -57,6 +56,7 @@ void uniq_t( } } } + template void remove_t( std::shared_ptr dst, @@ -102,6 +102,65 @@ void remove_t( } } } + +template +void replace_t( + std::shared_ptr& result, + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count) { + int64_t src_size = src->size(); + int64_t old_size = old->size(); + int64_t replacement_size = replacement->size(); + + T* src_buffer = src->data(); + T* old_buffer = old->data(); + T* replacement_buffer = replacement->data(); + + // Calculate the result size. If this ends up being slow we can try + // a single pass algorithm that grows the buffer using realloc. 
We can also + // try a better search algorithm because this has a worst case complexity + // O(src_size * old_size). + int64_t result_size = src_size; + int matches = 0; + if (old_size != replacement_size) { + for (int64_t i = 0; i < src_size; i++) { + if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) { + i += old_size - 1; + result_size += replacement_size - old_size; + matches++; + } + if (matches == count) { + break; + } + } + } + + result = std::make_shared(src->type(), result_size); + T* result_buffer = result->data(); + + matches = 0; + for (int64_t i = 0, j = 0; i < src_size; i++, j++) { + if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) { + std::copy( + replacement_buffer, + replacement_buffer + replacement_size, + result_buffer + j); + i += old_size - 1; + j += replacement_size - 1; + matches++; + } else { + result_buffer[j] = src_buffer[i]; + } + if (matches == count) { + std::copy( + src_buffer + i + 1, src_buffer + src_size, result_buffer + j + 1); + break; + } + } +} + } // namespace namespace mlx { namespace data { @@ -192,6 +251,16 @@ Sample merge_batch( return sample_batch; } +std::shared_ptr replace( + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count) { + std::shared_ptr result; + ARRAY_DISPATCH(src, replace_t, result, src, old, replacement, count); + return result; +} + } // namespace core } // namespace data } // namespace mlx diff --git a/mlx/data/core/Utils.h b/mlx/data/core/Utils.h index 8631bad..e3d3dca 100644 --- a/mlx/data/core/Utils.h +++ b/mlx/data/core/Utils.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#include "mlx/data/Array.h" #include "mlx/data/Sample.h" @@ -20,6 +20,12 @@ std::pair, std::shared_ptr> remove( double value, double pad); +std::shared_ptr replace( + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count); + Sample merge_batch( const std::vector& samples, const std::unordered_map& pad_values = {}, diff --git a/mlx/data/op/Replace.cpp b/mlx/data/op/Replace.cpp new file mode 100644 index 0000000..683f113 --- /dev/null +++ b/mlx/data/op/Replace.cpp @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/data/op/Replace.h" +#include "mlx/data/core/Utils.h" + +namespace mlx { +namespace data { +namespace op { + +Replace::Replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) + : key_(key), + old_(std::make_shared(old)), + replacement_(std::make_shared(replacement)), + count_(count) {} + +Sample Replace::apply(const Sample& sample) const { + auto value = sample::check_key(sample, key_, old_->type()); + value = core::replace(value, old_, replacement_, count_); + auto new_sample = sample; + new_sample[key_] = value; + return new_sample; +} + +} // namespace op +} // namespace data +} // namespace mlx diff --git a/mlx/data/op/Replace.h b/mlx/data/op/Replace.h new file mode 100644 index 0000000..9d181d1 --- /dev/null +++ b/mlx/data/op/Replace.h @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/data/op/Op.h" + +namespace mlx { +namespace data { +namespace op { + +class Replace : public Op { + public: + Replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count); + + virtual Sample apply(const Sample& sample) const override; + + private: + std::string key_; + std::shared_ptr old_; + std::shared_ptr replacement_; + int count_; +}; + +} // namespace op +} // namespace data +} // namespace mlx diff --git a/python/src/wrap_dataset.h b/python/src/wrap_dataset.h index 6b09bb8..32a38ff 100644 --- a/python/src/wrap_dataset.h +++ b/python/src/wrap_dataset.h @@ -945,6 +945,42 @@ void mlx_data_export_dataset(py::class_& base) { py::arg("pad") = 0, "Conditional :meth:`Buffer.remove_value`."); + base.def( + "replace", + &T::replace, + py::call_guard(), + py::arg("key"), + py::arg("old"), + py::arg("replacement"), + py::arg("count") = -1, + R"pbdoc( + Replace ``old`` with ``replacement`` in the array at ``key``. + + Example: + + .. code-block:: python + + # Replace ' ' with '▁' to prepare for SPM tokenization. + dset = dset.replace("text", " ", "\u2581") + + Args: + key (str): The sample key that contains the array we are operating on. + old (str): The character sequence that we are replacing. + replacement (str): The character sequence that we are replacing with. + count (int): Perform at most ``count`` replacements. Ignore if negative. + Default: ``-1``. + )pbdoc"); + base.def( + "replace_if", + &T::replace_if, + py::call_guard(), + py::arg("cond"), + py::arg("key"), + py::arg("old"), + py::arg("replacement"), + py::arg("count") = -1, + "Conditional :meth:`Buffer.replace`."); + base.def( "rename_key", &T::rename_key, diff --git a/python/tests/test_buffer.py b/python/tests/test_buffer.py index df8fcda..1f5da17 100644 --- a/python/tests/test_buffer.py +++ b/python/tests/test_buffer.py @@ -1,11 +1,11 @@ # Copyright © 2024 Apple Inc. 
-from unittest import TestCase +import unittest import mlx.data as dx -class TestBuffer(TestCase): +class TestBuffer(unittest.TestCase): def test__getitem__(self): n = 5 b = dx.buffer_from_vector(list(dict(i=i) for i in range(n))) @@ -18,3 +18,7 @@ def test__getitem__(self): _ = b[n] with self.assertRaises(IndexError): _ = b[-(n + 1)] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_replace.py b/python/tests/test_replace.py new file mode 100644 index 0000000..eb49f3b --- /dev/null +++ b/python/tests/test_replace.py @@ -0,0 +1,24 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.data as dx + + +class TestReplace(unittest.TestCase): + def test_replace(self): + s = "Hello world".encode() + dset = dx.buffer_from_vector([dict(text=s)]) + + ds = dset.replace("text", "world", "everybody!") + self.assertEqual(bytes(ds[0]["text"]), b"Hello everybody!") + + ds = dset.replace("text", "l", "b") + self.assertEqual(bytes(ds[0]["text"]), b"Hebbo worbd") + + ds = dset.replace("text", "l", "b", 2) + self.assertEqual(bytes(ds[0]["text"]), b"Hebbo world") + + +if __name__ == "__main__": + unittest.main() From 3df150de170d37d6bee2dc082fdb1b4ade2602fb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 17 Apr 2024 00:39:35 -0700 Subject: [PATCH 2/7] Start standardizing the SPM tokenization --- mlx/data/Dataset.cpp | 37 +++++++++++++++++++++++ mlx/data/Dataset.h | 11 +++++++ python/mlx/data/tokenizer_helpers.py | 44 +++++++++++----------------- python/src/wrap_dataset.h | 43 +++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 27 deletions(-) diff --git a/mlx/data/Dataset.cpp b/mlx/data/Dataset.cpp index 6411df9..3807dba 100644 --- a/mlx/data/Dataset.cpp +++ b/mlx/data/Dataset.cpp @@ -850,6 +850,43 @@ T Dataset::tokenize_if( } } +template +T Dataset::tokenize_spm( + const std::string& ikey, + std::shared_ptr> trie, + bool insert_space, + const std::string& okey) const { + static const std::string space = " "; + 
static const std::string new_space = "▁"; + + std::string okey_ = (okey.empty()) ? ikey : okey; + + return (sample_transform_if( + !okey.empty(), + [ikey, okey](const Sample& s) { + auto new_sample = s; + new_sample[okey] = new_sample[ikey]; + return new_sample; + }) + .pad_if(insert_space, okey_, 0, 1, 0, 32) + .replace(okey_, space, new_space) + .tokenize(okey_, trie, TokenizeMode::shortest)); +} + +template +T Dataset::tokenize_spm_if( + bool cond, + const std::string& ikey, + std::shared_ptr> trie, + bool insert_space, + const std::string& okey) const { + if (cond) { + return tokenize_spm(ikey, trie, insert_space, okey); + } else { + return T(self_); + } +} + // Implement Stream template <> Stream Dataset::transform_( diff --git a/mlx/data/Dataset.h b/mlx/data/Dataset.h index 72e3947..4fecbb0 100644 --- a/mlx/data/Dataset.h +++ b/mlx/data/Dataset.h @@ -396,6 +396,17 @@ class Dataset { bool ignore_unk = false, const std::vector& trie_key_scores = {}, const std::string& okey = "") const; + T tokenize_spm( + const std::string& ikey, + std::shared_ptr> trie, + bool insert_space = true, + const std::string& okey = "") const; + T tokenize_spm_if( + bool cond, + const std::string& ikey, + std::shared_ptr> trie, + bool insert_space = true, + const std::string& okey = "") const; protected: std::shared_ptr self_; diff --git a/python/mlx/data/tokenizer_helpers.py b/python/mlx/data/tokenizer_helpers.py index 626ebc8..deadcf4 100644 --- a/python/mlx/data/tokenizer_helpers.py +++ b/python/mlx/data/tokenizer_helpers.py @@ -1,5 +1,6 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2024 Apple Inc. +import math import re from pathlib import Path @@ -53,64 +54,53 @@ def iterate_tokens(spm_file): def to_special_token(token): return b"<0x" + token.hex().encode() + b">" - sep = "\u2581".encode("utf-8") - # We parse the model in two passes. 
First we save the tokens in tmp_tokens - # and tmp_scores and go back and replace special tokens that already exist - # to a special token representation. This happens so we can keep the same - # ids as the original sentencepiece model. + # and go back and replace special tokens that already exist or tokens that + # have a better score to a special token representation. This happens so we + # can keep the same ids as the original sentencepiece model. tokenmap = {} tmp_tokens = [] - tmp_scores = [] - max_scores = set() + trie_key_scores = [] for token, score in iterate_tokens(spm_file): - score = -score - if re.match(b"^<.*>$", token): - # Make sure to set the max score for all special tokens - max_scores.add(len(tmp_scores)) - hex_byte = re.match(b"^<0x(..)>$", token) if hex_byte: (token,) = hex_byte.groups() token = bytes.fromhex(token.decode()) - token = token.replace(sep, b" ") - # Token already exists so we should choose either the previous one or # this one. if token in tokenmap: existing_token_id = tokenmap[token] - existing_token_score = tmp_scores[existing_token_id] + existing_token_score = trie_key_scores[existing_token_id] # We should replace that token with our token if score < existing_token_score: tmp_tokens[existing_token_id] = to_special_token(token) - max_scores.add(existing_token_id) tmp_tokens.append(token) - tmp_scores.append(score) + trie_key_scores.append(score) tokenmap[token] = len(tmp_tokens) - 1 # We should ignore this token else: tmp_tokens.append(to_special_token(token)) - tmp_scores.append(score) - max_scores.add(len(tmp_tokens) - 1) + trie_key_scores.append(score) # Token doesn't exist so add it else: tmp_tokens.append(token) - tmp_scores.append(score) + trie_key_scores.append(score) tokenmap[token] = len(tmp_tokens) - 1 - # Set the max score to duplicates - max_score = max(tmp_scores) + 1 - for token_id in max_scores: - tmp_scores[token_id] = max_score + # SPM is a BPE tokenizer so it doesn't exactly work like the MLX tokenizer. 
+ # Favoring the shortest sequence and taking into account the scores at the + # same time yields the closest tokenization. + min_score = min(trie_key_scores) + for i in range(len(trie_key_scores)): + trie_key_scores[i] = -min_score - trie_key_scores[i] - # Build the trie and the scores + # Build the trie trie = CharTrie() - trie_key_scores = tmp_scores for token in tmp_tokens: if trie.search(token): raise RuntimeError(f"Token {token} found twice") diff --git a/python/src/wrap_dataset.h b/python/src/wrap_dataset.h index 32a38ff..1cc559a 100644 --- a/python/src/wrap_dataset.h +++ b/python/src/wrap_dataset.h @@ -1231,5 +1231,48 @@ void mlx_data_export_dataset(py::class_& base) { py::arg("trie_key_scores") = std::vector({}), py::arg("output_key") = "", "Conditional :meth:`Buffer.tokenize`."); + + base.def( + "tokenize_spm", + &T::tokenize_spm, + py::call_guard(), + py::arg("key"), + py::arg("trie"), + py::arg("insert_space") = true, + py::arg("output_key") = "", + R"pbcopy( + Preprocess the contents of the array at ``key`` and tokenize them + according to the SentencePiece tokenizer. + + This call is simply a convenience over calling pad, replace and + tokenize as follows: + + .. code-block:: python + + dset = ( + dset + .pad("text", 0, 1, 0, ord(" "), output_key="tokens") + .replace("tokens", " ", "\u2581") + .tokenize("tokens", trie) + ) + + Args: + key (str): The sample key that contains the array we are operating on. + trie (mlx.data.core.CharTrie): The trie to use for the tokenization. + insert_space (bool): Whether to prepend a space before the text. + (default: ``True``). + output_key (str): If it is not empty then write the result to this + key instead of overwriting ``key``. 
(default: '') + )pbcopy"); + base.def( + "tokenize_spm_if", + &T::tokenize_spm_if, + py::call_guard(), + py::arg("cond"), + py::arg("key"), + py::arg("trie"), + py::arg("insert_space") = true, + py::arg("output_key") = "", + "Conditional :meth:`Buffer.tokenize_spm`."); } } // namespace From 06add1131d0bff59cbcb4deb9b0abf97d717c991 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 17 Apr 2024 23:14:31 -0700 Subject: [PATCH 3/7] Add a seemingly working bpe tokenizer --- CMakeLists.txt | 1 + mlx/data/core/BPETokenizer.cpp | 158 +++++++++++++++++++++++++++++++++ mlx/data/core/BPETokenizer.h | 53 +++++++++++ mlx/data/core/Trie.h | 89 +++++++++++++++---- python/src/wrap_core.cpp | 81 +++++++++++++++++ 5 files changed, 363 insertions(+), 19 deletions(-) create mode 100644 mlx/data/core/BPETokenizer.cpp create mode 100644 mlx/data/core/BPETokenizer.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8bd5078..db21541 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ set(mlxdata-src ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/ThreadController.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/ThreadPool.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Tokenizer.cpp + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/BPETokenizer.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Levenshtein.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Utils.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/audio/Audio.cpp diff --git a/mlx/data/core/BPETokenizer.cpp b/mlx/data/core/BPETokenizer.cpp new file mode 100644 index 0000000..21cf428 --- /dev/null +++ b/mlx/data/core/BPETokenizer.cpp @@ -0,0 +1,158 @@ +// Copyright © 2024 Apple Inc. 
+ +#include +#include +#include +#include + +#include "mlx/data/core/BPETokenizer.h" +#include "mlx/data/core/Trie.h" + +namespace mlx { +namespace data { +namespace core { + +void BPEMerges::add( + const std::string& left, + const std::string& right, + int64_t token) { + auto [left_s, left_inserted] = strings_.insert(left); + auto [right_s, right_inserted] = strings_.insert(right); + + std::string_view left_v(*left_s); + std::string_view right_v(*right_s); + + auto left_it = merges_.find(left_v); + if (left_it == merges_.end()) { + merges_[left_v][right_v] = token; + } else { + auto right_it = left_it->second.find(right_v); + if (right_it == left_it->second.end()) { + left_it->second[right_v] = token; + } else { + right_it->second = std::min(token, right_it->second); + } + } +} + +std::pair BPEMerges::can_merge( + std::string_view left, + std::string_view right) const { + auto left_it = merges_.find(left); + if (left_it == merges_.end()) { + return {false, 0}; + } + auto right_it = left_it->second.find(right); + if (right_it == left_it->second.end()) { + return {false, 0}; + } + return {true, right_it->second}; +} + +BPETokenizer::BPETokenizer( + std::shared_ptr> symbols, + std::shared_ptr merges) + : symbols_(symbols), merges_(merges) {} + +std::vector BPETokenizer::tokenize(const std::string& input) const { + struct Symbol { + std::string_view value; + int64_t token; + }; + + struct Pair { + std::list::iterator left; + std::list::iterator right; + int64_t token; + std::string_view value; + + Pair( + std::list::iterator left, + std::list::iterator right, + int64_t token) + : left(left), + right(right), + token(token), + value(left->value.data(), left->value.size() + right->value.size()) {} + + bool operator<(const Pair& right) const { + return token >= right.token; + }; + }; + + // Transform the input to a sequence of basic symbols that will subsequently + // be merged. 
+ std::list symbols; + for (auto it = input.begin(); it != input.end(); it++) { + auto [node, length] = symbols_->search_longest_prefix(it, input.end()); + if (length == 0) { + std::ostringstream msg; + msg << "BPETokenizer: Unknown symbol '" << *it << "'"; + throw std::runtime_error(msg.str()); + } + symbols.push_back(Symbol{std::string_view(&*it, length), node->id}); + it += length - 1; + } + + std::priority_queue merge_queue; + + // Initialize the merge queue + auto left = symbols.begin(); + auto right = std::next(left); + while (right != symbols.end()) { + auto [can_merge, token] = merges_->can_merge(left->value, right->value); + if (can_merge) { + merge_queue.emplace(left, right, token); + } + left++; + right++; + } + + while (!merge_queue.empty()) { + Pair pair = std::move(merge_queue.top()); + merge_queue.pop(); + + // If both left and right are valid and the value matches the pair value it + // means we can merge freely. + if (pair.left->token >= 0 && pair.right->token >= 0 && + pair.value.size() == + pair.left->value.size() + pair.right->value.size() && + pair.value.data() == pair.left->value.data()) { + pair.left->token = pair.token; + pair.left->value = pair.value; + pair.right->token = -1; + continue; + } + + // This means that the pair is invalid for some reason, so we "eat" left + // and right until they both have valid symbols. Subsequently we push the + // new pair in the merge queue. 
+ while (pair.left->token == -1) { + pair.left = symbols.erase(pair.left); + pair.left--; + } + while (pair.right != symbols.end() && pair.right->token == -1) { + pair.right = symbols.erase(pair.right); + } + + auto [can_merge, token] = + merges_->can_merge(pair.left->value, pair.right->value); + if (can_merge) { + merge_queue.emplace(pair.left, pair.right, token); + } + } + + // Gather the final result in a vector + std::vector tokens; + for (auto& symbol : symbols) { + if (symbol.token >= 0) { + tokens.push_back(symbol.token); + } + } + + return tokens; +} + +} // namespace core +} // namespace data +} // namespace mlx diff --git a/mlx/data/core/BPETokenizer.h b/mlx/data/core/BPETokenizer.h new file mode 100644 index 0000000..413ccd0 --- /dev/null +++ b/mlx/data/core/BPETokenizer.h @@ -0,0 +1,53 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/data/core/Trie.h" + +namespace mlx { +namespace data { +namespace core { + +class BPEMerges { + public: + void add(const std::string& left, const std::string& right, int64_t token); + std::pair can_merge( + std::string_view left, + std::string_view right) const; + + template + std::pair + can_merge(iterator_type left, iterator_type middle, iterator_type end) const { + // switch to std::string_view(left, middle) when in C++20 + return can_merge( + std::string_view(&(*left), std::distance(left, middle)), + std::string_view(&(*middle), std::distance(middle, end))); + } + + private: + std::unordered_set strings_; + std::unordered_map< + std::string_view, + std::unordered_map> + merges_; +}; + +class BPETokenizer { + public: + BPETokenizer( + std::shared_ptr> symbols, + std::shared_ptr merges); + + std::vector tokenize(const std::string& input) const; + + private: + std::shared_ptr> symbols_; + std::shared_ptr merges_; +}; + +} // namespace core +} // namespace data +} // namespace mlx diff --git a/mlx/data/core/Trie.h b/mlx/data/core/Trie.h index 21e2f0f..0827c9b 100644 --- 
a/mlx/data/core/Trie.h +++ b/mlx/data/core/Trie.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -32,44 +33,83 @@ class Trie { nodes_.resize(1); nodes_.back().id = -1; // uid is 0 }; - const TrieNode* insert(const std::vector& key) { - TrieNode* node; - int64_t i; - std::tie(node, i) = partial_search_(key); - for (; i < key.size(); i++) { + + template + std::tuple*, int64_t> search_longest_prefix( + iterator_type it, + iterator_type end) const { + auto node = root_(); + int64_t i = 0; + auto valid_node = node; + int64_t valid_i = i; + while (it != end) { + auto kv = node->children.find(*it); + if (kv == node->children.end()) { + break; + } else { + node = kv->second; + i++; + it++; + if (node->accepts()) { + valid_node = node; + valid_i = i; + } + } + } + return std::make_tuple(valid_node, valid_i); + } + + template + const TrieNode* insert(iterator_type begin, iterator_type end) { + auto it = begin; + auto [node, i] = partial_search(it, end); + std::advance(it, i); // it += i but also supports sequential iterators + while (it != end) { nodes_.resize(nodes_.size() + 1); TrieNode* new_node = &nodes_.back(); new_node->uid = nodes_.size() - 1; new_node->id = -1; - node->children[key[i]] = new_node; + node->children[*it] = new_node; node = new_node; + it++; } if (!node->accepts()) { node->id = keys_.size(); - keys_.push_back(key); + keys_.emplace_back(begin, end); } return node; - }; - const TrieNode* search(const std::vector& key) { - auto res = partial_search_(key); - if (std::get<1>(res) != key.size()) { + } + + template + const TrieNode* search(iterator_type it, iterator_type end) { + auto [node, i] = partial_search(it, end); + if (i != std::distance(it, end) || !node->accepts()) { return nullptr; - } else { - auto node = std::get<0>(res); - return (node->accepts() ? 
node : nullptr); } - }; + return node; + } + + const TrieNode* insert(const std::vector& key) { + return insert(key.begin(), key.end()); + } + + const TrieNode* search(const std::vector& key) { + return search(key.begin(), key.end()); + } + const TrieNode* root() const { return &nodes_.front(); } + int64_t num_keys() const { return keys_.size(); } + const std::vector& key(int64_t id) const { return keys_.at(id); } - // helper for strings + // helpers for strings template < typename U = T, std::enable_if_t::value, char> = false> @@ -94,19 +134,30 @@ class Trie { TrieNode* root_() { return &nodes_.front(); } - std::tuple*, int64_t> partial_search_(const std::vector& key) { + + const TrieNode* root_() const { + return &nodes_.front(); + } + + template + std::tuple*, int64_t> partial_search( + iterator_type it, + iterator_type end) { auto node = root_(); int64_t i = 0; - for (; i < key.size(); i++) { - auto kv = node->children.find(key[i]); + while (it != end) { + auto kv = node->children.find(*it); if (kv == node->children.end()) { break; } else { node = kv->second; + i++; + it++; } } return std::make_tuple(node, i); } + std::deque> nodes_; std::vector> keys_; }; diff --git a/python/src/wrap_core.cpp b/python/src/wrap_core.cpp index 5207f36..fda4a74 100644 --- a/python/src/wrap_core.cpp +++ b/python/src/wrap_core.cpp @@ -8,6 +8,7 @@ #include "mlx/data/core/AWSFileFetcher.h" #endif +#include "mlx/data/core/BPETokenizer.h" #include "mlx/data/core/FileFetcher.h" #include "mlx/data/core/Graph.h" #include "mlx/data/core/Levenshtein.h" @@ -275,6 +276,86 @@ void init_mlx_data_core(py::module& m) { input (str): The input string to be tokenized. )pbcopy"); + py::class_>( + m, + "BPEMerges", + R"pbcopy( + A datastructure that holds all possible merges and allows querying + whether two strings can be merged in O(1) time. 
+ )pbcopy") + .def(py::init<>()) + .def( + "add", + &BPEMerges::add, + py::arg("left"), + py::arg("right"), + py::arg("token"), + R"pbcopy( + Add two strings as a possible merge that results in ``token``. + + Args: + left (str): The left side to be merged. + right (str): The right side to be merged. + token (int): The resulting token. + )pbcopy") + .def( + "can_merge", + [](std::shared_ptr& merges, + const std::string& left, + const std::string& right) -> std::optional { + auto [can_merge, token] = merges->can_merge( + std::string_view(left.data(), left.size()), + std::string_view(right.data(), right.size())); + + if (!can_merge) { + return {}; + } + + return token; + }, + py::arg("left"), + py::arg("right"), + R"pbcopy( + Check if ``left`` and ``right`` can be merged to one token. + + Args: + left (str): The left side of the possible token. + right (str): The right side of the possible token. + + Returns: + The token id is returned or None if ``left`` and ``right`` + couldn't be merged. + )pbcopy"); + + py::class_>( + m, + "BPETokenizer", + R"pbcopy( + A tokenizer that uses the BPE algorithm to tokenize strings. + + Args: + symbol_trie (mlx.data.core.CharTrie): The trie containing the basic + symbols that all merges start from. + merges (mlx.data.core.BPEMerges): The datastructure holding the bpe + merges. + )pbcopy") + .def( + py::init< + std::shared_ptr>, + std::shared_ptr>(), + py::arg("symbols"), + py::arg("merges")) + .def( + "tokenize", + &BPETokenizer::tokenize, + py::arg("input"), + R"pbcopy( + Tokenize the input according to the symbols and merges. + + Args: + input (str): The input string to be tokenized. 
+ )pbcopy"); + py::class_>( m, "FileFetcherHandle"); From 4ebe4f54c108748c266125ab4b6237e5507cda91 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 18 Apr 2024 03:25:08 -0700 Subject: [PATCH 4/7] Write BPE reading from spm model and fix BPETokenizer --- mlx/data/core/BPETokenizer.cpp | 73 ++++++++++------- mlx/data/core/Trie.h | 20 ++--- python/mlx/data/tokenizer_helpers.py | 113 +++++++++++++++++++++------ python/src/wrap_core.cpp | 11 ++- 4 files changed, 153 insertions(+), 64 deletions(-) diff --git a/mlx/data/core/BPETokenizer.cpp b/mlx/data/core/BPETokenizer.cpp index 21cf428..7360e91 100644 --- a/mlx/data/core/BPETokenizer.cpp +++ b/mlx/data/core/BPETokenizer.cpp @@ -61,14 +61,14 @@ std::vector BPETokenizer::tokenize(const std::string& input) const { }; struct Pair { - std::list::iterator left; - std::list::iterator right; + std::vector::iterator left; + std::vector::iterator right; int64_t token; std::string_view value; Pair( - std::list::iterator left, - std::list::iterator right, + std::vector::iterator left, + std::vector::iterator right, int64_t token) : left(left), right(right), @@ -82,7 +82,8 @@ std::vector BPETokenizer::tokenize(const std::string& input) const { // Transform the input to a sequence of basic symbols that will subsequently // be merged. - std::list symbols; + std::vector symbols; + symbols.reserve(input.size()); for (auto it = input.begin(); it != input.end(); it++) { auto [node, length] = symbols_->search_longest_prefix(it, input.end()); if (length == 0) { @@ -112,33 +113,51 @@ std::vector BPETokenizer::tokenize(const std::string& input) const { Pair pair = std::move(merge_queue.top()); merge_queue.pop(); - // If both left and right are valid and the value matches the pair value it - // means we can merge freely. 
- if (pair.left->token >= 0 && pair.right->token >= 0 && - pair.value.size() == - pair.left->value.size() + pair.right->value.size() && - pair.value.data() == pair.left->value.data()) { - pair.left->token = pair.token; - pair.left->value = pair.value; - pair.right->token = -1; + // Skip invalidated pairs + if (pair.left->token < 0 || pair.right->token < 0) { continue; } - - // This means that the pair is invalid for some reason, so we "eat" left - // and right until they both have valid symbols. Subsequently we push the - // new pair in the merge queue. - while (pair.left->token == -1) { - pair.left = symbols.erase(pair.left); - pair.left--; + if (pair.value.size() != + pair.left->value.size() + pair.right->value.size()) { + continue; } - while (pair.right != symbols.end() && pair.right->token == -1) { - pair.right = symbols.erase(pair.right); + if (pair.value.data() != pair.left->value.data()) { + continue; } - auto [can_merge, token] = - merges_->can_merge(pair.left->value, pair.right->value); - if (can_merge) { - merge_queue.emplace(pair.left, pair.right, token); + // Yay! Valid pair, let's merge into the left one. + pair.left->token = pair.token; + pair.left->value = pair.value; + + // Invalidate our neighbor which we just merged into ourselves. + pair.right->token = -1; + + // Find the first valid symbol to our left to check for a possible merge. + if (pair.left != symbols.begin()) { + auto neighbor_left = std::prev(pair.left); + while (neighbor_left != symbols.begin() && neighbor_left->token == -1) { + neighbor_left--; + } + if (neighbor_left->token != -1) { + auto [can_merge, token] = + merges_->can_merge(neighbor_left->value, pair.left->value); + if (can_merge) { + merge_queue.emplace(neighbor_left, pair.left, token); + } + } + } + + // Do the same to our right. 
+ auto neighbor_right = std::next(pair.right); + while (neighbor_right != symbols.end() && neighbor_right->token == -1) { + neighbor_right++; + } + if (neighbor_right->token != -1) { + auto [can_merge, token] = + merges_->can_merge(pair.left->value, neighbor_right->value); + if (can_merge) { + merge_queue.emplace(pair.left, neighbor_right, token); + } } } diff --git a/mlx/data/core/Trie.h b/mlx/data/core/Trie.h index 0827c9b..5924631 100644 --- a/mlx/data/core/Trie.h +++ b/mlx/data/core/Trie.h @@ -60,7 +60,9 @@ class Trie { } template - const TrieNode* insert(iterator_type begin, iterator_type end) { + const TrieNode* + insert(iterator_type begin, iterator_type end, int64_t id = -1) { + id = (id < 0) ? keys_.size() : id; auto it = begin; auto [node, i] = partial_search(it, end); std::advance(it, i); // it += i but also supports sequential iterators @@ -74,8 +76,8 @@ class Trie { it++; } if (!node->accepts()) { - node->id = keys_.size(); - keys_.emplace_back(begin, end); + node->id = id; + keys_.emplace(id, std::vector(begin, end)); } return node; } @@ -89,8 +91,8 @@ class Trie { return node; } - const TrieNode* insert(const std::vector& key) { - return insert(key.begin(), key.end()); + const TrieNode* insert(const std::vector& key, int64_t id = -1) { + return insert(key.begin(), key.end(), id); } const TrieNode* search(const std::vector& key) { @@ -113,14 +115,14 @@ class Trie { template < typename U = T, std::enable_if_t::value, char> = false> - const TrieNode* insert(const std::string& key) { - return insert(std::vector(key.begin(), key.end())); + const TrieNode* insert(const std::string& key, int64_t id = -1) { + return insert(key.begin(), key.end(), id); }; template < typename U = T, std::enable_if_t::value, char> = false> const TrieNode* search(const std::string& key) { - return search(std::vector(key.begin(), key.end())); + return search(key.begin(), key.end()); }; template < typename U = T, @@ -159,7 +161,7 @@ class Trie { } std::deque> nodes_; - 
std::vector> keys_; + std::unordered_map> keys_; }; } // namespace core diff --git a/python/mlx/data/tokenizer_helpers.py b/python/mlx/data/tokenizer_helpers.py index deadcf4..5ae3eb3 100644 --- a/python/mlx/data/tokenizer_helpers.py +++ b/python/mlx/data/tokenizer_helpers.py @@ -9,7 +9,31 @@ except ImportError: SentencePieceProcessor = None -from .core import CharTrie +from .core import BPEMerges, CharTrie + + +def _iterate_spm_tokens(spm_file): + if spm_file.endswith(".model"): + if SentencePieceProcessor is None: + raise RuntimeError( + "sentencepiece must be installed to read directly from a binary model" + ) + + spm_tok = SentencePieceProcessor(spm_file) + for i in range(spm_tok.vocab_size()): + yield spm_tok.id_to_piece(i).encode("utf-8"), spm_tok.get_score(i) + + elif spm_file.endswith(".vocab"): + f = open(spm_file, "rb") + for line in f: + line = line.rstrip() + token, score = line.split(b"\t") + yield token, float(score) + + else: + raise ValueError( + f"Sentencepiece file extenstion must be in [.vocab, .model] but it was {spm_file}" + ) def read_trie_from_spm(spm_file): @@ -19,6 +43,17 @@ def read_trie_from_spm(spm_file): however if the vocabulary and the scores are exported the file can be read without installing sentencepiece. + .. note:: + + Sentencepiece models are almost always BPE models with scores being the + associated log likelihood from a unigram language model. Using the + :class:`mlx.data.core.CharTrie` and the loaded scores will provide the + shortest possible tokenization with the highest possible log likelihood + but it can be slightly different than the BPE one. + + Use :func:`read_bpe_from_spm` to load the model to be used with a + :class:`mlx.data.core.BPETokenizer`. + Args: spm_file (str): Either a sentencepiece model file or a vocab file extracted from a sentencepiece model. @@ -28,29 +63,6 @@ def read_trie_from_spm(spm_file): corresponding weights from the SPM mdoel. 
""" - def iterate_tokens(spm_file): - if spm_file.endswith(".model"): - if SentencePieceProcessor is None: - raise RuntimeError( - "sentencepiece must be installed to read directly from a binary model" - ) - - spm_tok = SentencePieceProcessor(spm_file) - for i in range(spm_tok.vocab_size()): - yield spm_tok.id_to_piece(i).encode("utf-8"), spm_tok.get_score(i) - - elif spm_file.endswith(".vocab"): - f = open(spm_file, "rb") - for line in f: - line = line.rstrip() - token, score = line.split(b"\t") - yield token, float(score) - - else: - raise ValueError( - f"Sentencepiece file extenstion must be in [.vocab, .model] but it was {spm_file}" - ) - def to_special_token(token): return b"<0x" + token.hex().encode() + b">" @@ -61,7 +73,7 @@ def to_special_token(token): tokenmap = {} tmp_tokens = [] trie_key_scores = [] - for token, score in iterate_tokens(spm_file): + for token, score in _iterate_spm_tokens(spm_file): if re.match(b"^<.*>$", token): hex_byte = re.match(b"^<0x(..)>$", token) if hex_byte: @@ -109,6 +121,57 @@ def to_special_token(token): return trie, trie_key_scores +def read_bpe_from_spm(spm_file): + symbols = [] + merged = [] + tokenmap = {} + for token_id, (token, score) in enumerate(_iterate_spm_tokens(spm_file)): + if re.match(b"^<.*>$", token): + hex_byte = re.match(b"^<0x(..)>$", token) + if hex_byte: + (token,) = hex_byte.groups() + token = bytes.fromhex(token.decode()) + + if len(token) == 1 or score == 0 or len(token.decode(errors="ignore")) == 1: + symbols.append(token) + else: + merged.append(token) + + tokenmap[token] = token_id + + trie = CharTrie() + for s in symbols: + trie.insert(s, tokenmap[s]) + + merges = BPEMerges() + + def bpe(tokenmap, token, max_rank): + parts = list(token) + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = tokenmap.get((pair[0] + pair[1]).encode()) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank 
is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = ( + parts[:min_idx] + + [parts[min_idx] + parts[min_idx + 1]] + + parts[min_idx + 2 :] + ) + return parts + + for t in merged: + left, right = bpe(tokenmap, t.decode(), tokenmap[t]) + merges.add(left.encode(), right.encode(), tokenmap[t]) + + return trie, merges + + def read_trie_from_vocab(vocab_file): """Read an :class:`mlx.data.core.CharTrie` from a file with one token per line. diff --git a/python/src/wrap_core.cpp b/python/src/wrap_core.cpp index fda4a74..3ca5574 100644 --- a/python/src/wrap_core.cpp +++ b/python/src/wrap_core.cpp @@ -134,21 +134,26 @@ void init_mlx_data_core(py::module& m) { .def( "insert", [](std::shared_ptr> trie, - std::variant> token) { + std::variant> token, + int64_t id) { if (std::holds_alternative(token)) { - return trie->insert(std::get(token)); + return trie->insert(std::get(token), id); } else { - return trie->insert(std::get>(token)); + return trie->insert(std::get>(token), id); } }, py::return_value_policy::reference_internal, py::arg("token"), + py::arg("id") = -1, R"pbcopy( Insert a token in the trie making a new token if it doesn't already exist. Args: token (str or list[char]): The new token to be inserted given either as a string or a list of characters. + id (int, optional): The id to assign to the new token to be + inserted. If negative then use ``num_keys()`` as default. + Default: ``-1``. 
)pbcopy") .def( "search", From 596ba140b81c6316287f65dab95042fefb037244 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 18 Apr 2024 03:40:57 -0700 Subject: [PATCH 5/7] Add some docs and a small test --- python/mlx/data/tokenizer_helpers.py | 19 +++++++++++++++++ python/tests/test_bpe.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 python/tests/test_bpe.py diff --git a/python/mlx/data/tokenizer_helpers.py b/python/mlx/data/tokenizer_helpers.py index 5ae3eb3..13ef89b 100644 --- a/python/mlx/data/tokenizer_helpers.py +++ b/python/mlx/data/tokenizer_helpers.py @@ -122,6 +122,25 @@ def to_special_token(token): def read_bpe_from_spm(spm_file): + """Read a sentencepiece file and decompose it to a symbol trie and BPE + merges for use with :class:`mlx.data.core.BPETokenizer`. + + Because it isn't straightforward to extract the merges from the SPM file, + we create a trie of basic symbols by considering all single unicode + character tokens as basic symbols as well as any special tokens provided. + + To extract the merges we run the BPE algorithm on the tokens in order of + probability as suggested in https://github.com/openai/tiktoken/issues/60 + for exporting an SPM model to huggingface tokenizers. + + Args: + spm_file (str): Either a sentencepiece model file or a vocab file + extracted from a sentencepiece model. + + Returns: + tuple[:class:`mlx.data.core.CharTrie`, :class:`mlx.data.core.BPEMerges`]: The + trie and the corresponding BPE merges from the SPM model. + """ symbols = [] merged = [] tokenmap = {} diff --git a/python/tests/test_bpe.py b/python/tests/test_bpe.py new file mode 100644 index 0000000..f76a222 --- /dev/null +++ b/python/tests/test_bpe.py @@ -0,0 +1,31 @@ +# Copyright © 2024 Apple Inc. 
+ +import string +import unittest + +from mlx.data.core import BPEMerges, BPETokenizer, CharTrie + + +class TestBpe(unittest.TestCase): + def test_bpe(self): + symbols = CharTrie() + symbols.insert(" ") + for s in string.ascii_letters: + symbols.insert(s) + n = symbols.num_keys() + merges = BPEMerges() + + tokenizer = BPETokenizer(symbols, merges) + + self.assertEqual(tokenizer.tokenize("abcd"), [1, 2, 3, 4]) + + merges.add("a", "b", n + 1) + self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, 3, 4]) + + merges.add("c", "d", n + 2) + merges.add("b", "cd", n + 3) + self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, n + 2]) + + +if __name__ == "__main__": + unittest.main() From 4e61bbb62c94c7bdfcc34361f92ee53c687eb222 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 18 Apr 2024 04:01:42 -0700 Subject: [PATCH 6/7] Add the BPE tokenize op --- mlx/data/Dataset.cpp | 32 ++++++++------------------ mlx/data/Dataset.h | 12 +++++----- mlx/data/core/BPETokenizer.cpp | 2 +- mlx/data/core/BPETokenizer.h | 2 +- mlx/data/op/Tokenize.cpp | 17 ++++++++++++++ mlx/data/op/Tokenize.h | 16 +++++++++++++ python/src/wrap_dataset.h | 42 ++++++++++++++-------------------- 7 files changed, 68 insertions(+), 55 deletions(-) diff --git a/mlx/data/Dataset.cpp b/mlx/data/Dataset.cpp index 3807dba..ff3c7df 100644 --- a/mlx/data/Dataset.cpp +++ b/mlx/data/Dataset.cpp @@ -851,37 +851,25 @@ T Dataset::tokenize_if( } template -T Dataset::tokenize_spm( +T Dataset::tokenize_bpe( const std::string& ikey, - std::shared_ptr> trie, - bool insert_space, + std::shared_ptr> symbols, + std::shared_ptr merges, const std::string& okey) const { - static const std::string space = " "; - static const std::string new_space = "▁"; - - std::string okey_ = (okey.empty()) ? 
ikey : okey; - - return (sample_transform_if( - !okey.empty(), - [ikey, okey](const Sample& s) { - auto new_sample = s; - new_sample[okey] = new_sample[ikey]; - return new_sample; - }) - .pad_if(insert_space, okey_, 0, 1, 0, 32) - .replace(okey_, space, new_space) - .tokenize(okey_, trie, TokenizeMode::shortest)); + return transform_( + std::make_shared(ikey, symbols, merges, okey)); } template -T Dataset::tokenize_spm_if( +T Dataset::tokenize_bpe_if( bool cond, const std::string& ikey, - std::shared_ptr> trie, - bool insert_space, + std::shared_ptr> symbols, + std::shared_ptr merges, const std::string& okey) const { if (cond) { - return tokenize_spm(ikey, trie, insert_space, okey); + return transform_( + std::make_shared(ikey, symbols, merges, okey)); } else { return T(self_); } diff --git a/mlx/data/Dataset.h b/mlx/data/Dataset.h index 4fecbb0..6f7ce1b 100644 --- a/mlx/data/Dataset.h +++ b/mlx/data/Dataset.h @@ -396,16 +396,16 @@ class Dataset { bool ignore_unk = false, const std::vector& trie_key_scores = {}, const std::string& okey = "") const; - T tokenize_spm( + T tokenize_bpe( const std::string& ikey, - std::shared_ptr> trie, - bool insert_space = true, + std::shared_ptr> symbols, + std::shared_ptr merges, const std::string& okey = "") const; - T tokenize_spm_if( + T tokenize_bpe_if( bool cond, const std::string& ikey, - std::shared_ptr> trie, - bool insert_space = true, + std::shared_ptr> symbols, + std::shared_ptr merges, const std::string& okey = "") const; protected: diff --git a/mlx/data/core/BPETokenizer.cpp b/mlx/data/core/BPETokenizer.cpp index 7360e91..5620447 100644 --- a/mlx/data/core/BPETokenizer.cpp +++ b/mlx/data/core/BPETokenizer.cpp @@ -54,7 +54,7 @@ BPETokenizer::BPETokenizer( std::shared_ptr merges) : symbols_(symbols), merges_(merges) {} -std::vector BPETokenizer::tokenize(const std::string& input) const { +std::vector BPETokenizer::tokenize(std::string_view input) const { struct Symbol { std::string_view value; int64_t token; diff --git 
a/mlx/data/core/BPETokenizer.h b/mlx/data/core/BPETokenizer.h index 413ccd0..343a3fc 100644 --- a/mlx/data/core/BPETokenizer.h +++ b/mlx/data/core/BPETokenizer.h @@ -41,7 +41,7 @@ class BPETokenizer { std::shared_ptr> symbols, std::shared_ptr merges); - std::vector tokenize(const std::string& input) const; + std::vector tokenize(std::string_view input) const; private: std::shared_ptr> symbols_; diff --git a/mlx/data/op/Tokenize.cpp b/mlx/data/op/Tokenize.cpp index 5255aae..26d9ccc 100644 --- a/mlx/data/op/Tokenize.cpp +++ b/mlx/data/op/Tokenize.cpp @@ -5,6 +5,7 @@ namespace mlx { namespace data { namespace op { + Tokenize::Tokenize( const std::string& ikey, std::shared_ptr> trie, @@ -15,6 +16,7 @@ Tokenize::Tokenize( : KeyTransformOp(ikey, okey), tokenizer_(trie, ignore_unk, trie_key_scores), mode_(mode) {} + std::shared_ptr Tokenize::apply_key( const std::shared_ptr& src) const { std::string str( @@ -34,6 +36,21 @@ std::shared_ptr Tokenize::apply_key( return std::make_shared(tokens); } + +BPETokenize::BPETokenize( + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey) + : KeyTransformOp(ikey, okey), tokenizer_(symbols, merges) {} + +std::shared_ptr BPETokenize::apply_key( + const std::shared_ptr& src) const { + auto tokens = tokenizer_.tokenize(std::string_view( + reinterpret_cast(src->data()), src->size() * src->itemsize())); + return std::make_shared(tokens); +} + } // namespace op } // namespace data } // namespace mlx diff --git a/mlx/data/op/Tokenize.h b/mlx/data/op/Tokenize.h index 8b14a4c..5c5346b 100644 --- a/mlx/data/op/Tokenize.h +++ b/mlx/data/op/Tokenize.h @@ -2,6 +2,7 @@ #pragma once +#include "mlx/data/core/BPETokenizer.h" #include "mlx/data/core/Tokenizer.h" #include "mlx/data/core/Trie.h" #include "mlx/data/op/KeyTransform.h" @@ -30,6 +31,21 @@ class Tokenize : public KeyTransformOp { TokenizeMode mode_; }; +class BPETokenize : public KeyTransformOp { + public: + BPETokenize( + const std::string& 
ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey = ""); + + virtual std::shared_ptr apply_key( + const std::shared_ptr& src) const override; + + private: + core::BPETokenizer tokenizer_; +}; + } // namespace op } // namespace data } // namespace mlx diff --git a/python/src/wrap_dataset.h b/python/src/wrap_dataset.h index 1cc559a..7d38e74 100644 --- a/python/src/wrap_dataset.h +++ b/python/src/wrap_dataset.h @@ -1233,46 +1233,38 @@ void mlx_data_export_dataset(py::class_& base) { "Conditional :meth:`Buffer.tokenize`."); base.def( - "tokenize_spm", - &T::tokenize_spm, + "tokenize_bpe", + &T::tokenize_bpe, py::call_guard(), py::arg("key"), - py::arg("trie"), - py::arg("insert_space") = true, + py::arg("symbols"), + py::arg("merges"), py::arg("output_key") = "", R"pbcopy( - Preprocess the contents of the array at ``key`` and tokenize them - according to the SentencePiece tokenizer. - - This call is simply a convenience over calling pad, replace and - tokenize as follows: + Tokenize the the contents of the array at ``key`` using the BPE merging + algorithm. - .. code-block:: python - - dset = ( - dset - .pad("text", 0, 1, 0, ord(" "), output_key="tokens") - .replace("tokens", " ", "\u2581") - .tokenize("tokens", trie) - ) + For instance this can be used to match the tokenization of the + Sentencepiece tokenizers. Args: key (str): The sample key that contains the array we are operating on. - trie (mlx.data.core.CharTrie): The trie to use for the tokenization. - insert_space (bool): Whether to prepend a space before the text. - (default: ``True``). + symbols (mlx.data.core.CharTrie): A trie containing the basic symbols + to use for the tokenization. + merges (mlx.data.core.BPEMerges): A datastructure containing the + merges of the basic symbols in order of priority. output_key (str): If it is not empty then write the result to this key instead of overwriting ``key``. 
(default: '') )pbcopy"); base.def( - "tokenize_spm_if", - &T::tokenize_spm_if, + "tokenize_bpe_if", + &T::tokenize_bpe_if, py::call_guard(), py::arg("cond"), py::arg("key"), - py::arg("trie"), - py::arg("insert_space") = true, + py::arg("symbols"), + py::arg("merges"), py::arg("output_key") = "", - "Conditional :meth:`Buffer.tokenize_spm`."); + "Conditional :meth:`Buffer.tokenize_bpe`."); } } // namespace From 17d49604cfc170504b02ce08a6ed4965de465cd2 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 18 Apr 2024 13:29:12 -0700 Subject: [PATCH 7/7] Add pointers to left and right in symbol --- mlx/data/core/BPETokenizer.cpp | 61 +++++++++++++++++----------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/mlx/data/core/BPETokenizer.cpp b/mlx/data/core/BPETokenizer.cpp index 5620447..21a0089 100644 --- a/mlx/data/core/BPETokenizer.cpp +++ b/mlx/data/core/BPETokenizer.cpp @@ -1,7 +1,5 @@ // Copyright © 2024 Apple Inc. -#include -#include #include #include @@ -57,6 +55,8 @@ BPETokenizer::BPETokenizer( std::vector BPETokenizer::tokenize(std::string_view input) const { struct Symbol { std::string_view value; + int left; + int right; int64_t token; }; @@ -91,7 +91,11 @@ std::vector BPETokenizer::tokenize(std::string_view input) const { msg << "BPETokenizer: Unknown symbol '" << *it << "'"; throw std::runtime_error(msg.str()); } - symbols.push_back(Symbol{std::string_view(&*it, length), node->id}); + symbols.push_back(Symbol{ + std::string_view(&*it, length), + static_cast(symbols.size() - 1), + static_cast(symbols.size() + 1), + node->id}); it += length - 1; } @@ -110,53 +114,50 @@ std::vector BPETokenizer::tokenize(std::string_view input) const { } while (!merge_queue.empty()) { - Pair pair = std::move(merge_queue.top()); + Pair top = std::move(merge_queue.top()); merge_queue.pop(); // Skip invalidated pairs - if (pair.left->token < 0 || pair.right->token < 0) { + if (top.left->token < 0 || top.right->token < 0) { continue; } - if 
(pair.value.size() != - pair.left->value.size() + pair.right->value.size()) { + if (top.value.size() != top.left->value.size() + top.right->value.size()) { continue; } - if (pair.value.data() != pair.left->value.data()) { + if (top.value.data() != top.left->value.data()) { continue; } // Yay! Valid pair, let's merge into the left one. - pair.left->token = pair.token; - pair.left->value = pair.value; + top.left->token = top.token; + top.left->value = top.value; // Invalidate our neighbor which we just merged into ourselves. - pair.right->token = -1; + top.right->token = -1; - // Find the first valid symbol to our left to check for a possible merge. - if (pair.left != symbols.begin()) { - auto neighbor_left = std::prev(pair.left); - while (neighbor_left != symbols.begin() && neighbor_left->token == -1) { - neighbor_left--; - } - if (neighbor_left->token != -1) { - auto [can_merge, token] = - merges_->can_merge(neighbor_left->value, pair.left->value); - if (can_merge) { - merge_queue.emplace(neighbor_left, pair.left, token); - } + // Adjust the pointers to neighboring symbols + top.left->right = top.right->right; + if (top.right->right < symbols.size()) { + symbols[top.right->right].left = top.right->left; + } + + // Check for a possible merge to the left. + if (top.left != symbols.begin()) { + auto neighbor = symbols.begin() + top.left->left; + auto [can_merge, token] = + merges_->can_merge(neighbor->value, top.left->value); + if (can_merge) { + merge_queue.emplace(neighbor, top.left, token); } } // Do the same to our right. 
- auto neighbor_right = std::next(pair.right); - while (neighbor_right != symbols.end() && neighbor_right->token == -1) { - neighbor_right++; - } - if (neighbor_right->token != -1) { + if (top.left->right < symbols.size()) { + auto neighbor = symbols.begin() + top.left->right; auto [can_merge, token] = - merges_->can_merge(pair.left->value, neighbor_right->value); + merges_->can_merge(top.left->value, neighbor->value); if (can_merge) { - merge_queue.emplace(pair.left, neighbor_right, token); + merge_queue.emplace(top.left, neighbor, token); } } }