diff --git a/CMakeLists.txt b/CMakeLists.txt index 08e2e79..629f995 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,7 @@ set(tokenizers_source_files ${CMAKE_CURRENT_SOURCE_DIR}/src/bpe_tokenizer_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/hf_tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/llama2c_tokenizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/normalizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/pre_tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/re2_regex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/regex.cpp diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h index 97542bf..5e5c05d 100644 --- a/include/pytorch/tokenizers/bpe_tokenizer_base.h +++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h @@ -33,7 +33,7 @@ namespace detail { using TokenMap = StringIntegerMap<>; template -static Result buildTokenMap( +static Result build_token_map( std::vector> container) { static_assert( std::is_same_v || @@ -82,7 +82,7 @@ static Result buildTokenMap( }; template -static Result buildTokenMap( +static Result build_token_map( const TContainer& container, TTokenAccessor token_accessor, TRankAccessor rank_accessor) { @@ -103,7 +103,7 @@ static Result buildTokenMap( pairs.emplace_back(token_accessor(value), rank_accessor(value)); } - return buildTokenMap(std::move(pairs)); + return build_token_map(std::move(pairs)); } inline Result> build_special_token_regex( @@ -152,10 +152,19 @@ class BPETokenizerBase : public Tokenizer { const std::string& text, const TokenMap& allowed_special) const; - Result> byte_pair_encode_( + virtual Result> byte_pair_encode_( const std::string& piece, const TokenMap& encoder) const; + // Virtual method for BPE merging - can be overridden by derived classes + // The passed in `ranks` param for the base impl is just a regular token map + // and that the actual ranks are derived implicitly from the regular token + // map. This is the same implementation as Tiktoken. 
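+  //
+  // Illustrative call pattern only (an assumption for documentation, not
+  // part of the API contract): `ranks` scores candidate merges and `func`
+  // maps each final [start, stop) byte span of `piece` back to a token id,
+  // mirroring how byte_pair_encode_ invokes it, e.g.
+  //
+  //   auto ids = _byte_pair_merge(
+  //       piece, *token_map_, [&](uint64_t start, uint64_t stop) {
+  //         return *token_map_->tryGetInteger(piece.substr(start, stop - start));
+  //       });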
+ virtual std::vector _byte_pair_merge( + const std::string& piece, + const TokenMap& ranks, + std::function func) const; + // Protected members that can be overloaded by other BPE tokenizers std::unique_ptr special_token_regex_; std::optional token_map_; diff --git a/include/pytorch/tokenizers/hf_tokenizer.h b/include/pytorch/tokenizers/hf_tokenizer.h index 54869c7..8d6bd1d 100644 --- a/include/pytorch/tokenizers/hf_tokenizer.h +++ b/include/pytorch/tokenizers/hf_tokenizer.h @@ -18,11 +18,155 @@ // Local #include #include +#include #include #include #include namespace tokenizers { +namespace detail { + +// Hash function for std::pair +struct PairHash { + std::size_t operator()(const std::pair& p) const { + return std::hash{}(p.first) ^ + (std::hash{}(p.second) << 1); + } +}; + +// Type alias for BPE merge map: (token_id_1, token_id_2) -> (rank, +// merged_token_id) +using MergeMap = std::unordered_map< + std::pair, + std::pair, + PairHash>; + +// Utility function to build merge ranks map from merge rules +template +inline Result build_merge_ranks_map( + const TMergeMap& merge_map, + const TokenMap& token_map) { + // Static assertions to verify TMergeMap has the expected key and value types + using KeyType = typename TMergeMap::key_type; + using ValueType = typename TMergeMap::mapped_type; + + static_assert( + std::is_same_v>, + "TMergeMap key type must be std::pair"); + + static_assert( + std::is_same_v>, + "TMergeMap value type must be std::pair"); + + // Use a map to handle duplicates - keep the lowest rank (highest priority) + std::unordered_map unique_merge_ranks; + + for (const auto& [pair, rank_and_id] : merge_map) { + uint64_t first_id = pair.first; + uint64_t second_id = pair.second; + uint64_t rank = rank_and_id.first; + + // Get the token strings for the pair + auto first_token = token_map.tryGetString(first_id); + auto second_token = token_map.tryGetString(second_id); + + if (first_token && second_token) { + std::string merged_token = + std::string(*first_token) + std::string(*second_token); + + // Keep the entry with the lowest rank (highest priority in BPE) + auto it = unique_merge_ranks.find(merged_token); + if (it == unique_merge_ranks.end() || rank < it->second) { + unique_merge_ranks[merged_token] = rank; + } + } + } + + // Convert to vector for buildTokenMap + std::vector> merge_rank_pairs; + merge_rank_pairs.reserve(unique_merge_ranks.size()); + + for (const auto& [token, rank] : unique_merge_ranks) { + merge_rank_pairs.emplace_back(token, rank); + } + + return build_token_map(std::move(merge_rank_pairs)); +} + +} // namespace detail + +// Simple Word structure to mimic Rust's Word behavior +struct HFWord { + std::vector tokens; + std::vector byte_lengths; + + void add(uint64_t token_id, size_t byte_len) { + tokens.push_back(token_id); + byte_lengths.push_back(byte_len); + } + + size_t size() const { + return tokens.size(); + } + + // Apply all possible merges using the merge ranks + void merge_all( + const detail::TokenMap& merge_ranks, + const detail::TokenMap& token_map) { + while (tokens.size() > 1) { + std::optional> best_merge; + + // Find the best merge (lowest rank) among adjacent token pairs + for (size_t i = 0; i < tokens.size() - 1; ++i) { + // Create the merged token string to look up its rank + auto first_token = token_map.tryGetString(tokens[i]); + auto second_token = token_map.tryGetString(tokens[i + 1]); + + if (first_token && second_token) { + std::string merged_token = + std::string(*first_token) + std::string(*second_token); + auto rank = 
merge_ranks.tryGetInteger(merged_token); + + if (rank && (!best_merge || *rank < best_merge->second)) { + best_merge = std::make_pair(i, static_cast(*rank)); + } + } + } + + if (!best_merge) { + break; // No more merges possible + } + + // Apply the best merge + size_t merge_idx = best_merge->first; + + // Get the merged token ID + auto first_token = token_map.tryGetString(tokens[merge_idx]); + auto second_token = token_map.tryGetString(tokens[merge_idx + 1]); + + if (first_token && second_token) { + std::string merged_token = + std::string(*first_token) + std::string(*second_token); + auto merged_id = token_map.tryGetInteger(merged_token); + + if (merged_id) { + // Replace the two tokens with the merged token + tokens[merge_idx] = *merged_id; + byte_lengths[merge_idx] += byte_lengths[merge_idx + 1]; + + // Remove the second token + tokens.erase(tokens.begin() + merge_idx + 1); + byte_lengths.erase(byte_lengths.begin() + merge_idx + 1); + } else { + break; // Merged token not found in vocabulary + } + } else { + break; // Original tokens not found in vocabulary + } + } + } +}; + class HFTokenizer : public detail::BPETokenizerBase { public: /*-- Public Interface --*/ @@ -46,8 +190,25 @@ class HFTokenizer : public detail::BPETokenizerBase { void _decode(const std::string& input, std::string& ret) const override; + Result> byte_pair_encode_( + const std::string& piece, + const detail::TokenMap& encoder) const override; + + // Override the virtual _byte_pair_merge method to use explicit merges + // specified in tokenizer.json. Different from Tiktoken (another user of + // BPETokenizerBase, but doesn't use explicit merge rules). + std::vector _byte_pair_merge( + const std::string& piece, + const detail::TokenMap& ranks, + std::function func) const override; + + Normalizer::Ptr _normalizer; PreTokenizer::Ptr _pretokenizer; TokenDecoder::Ptr _decoder; + + std::unique_ptr merge_map_; + std::optional + merge_ranks_; // Pre-computed merge ranks for BPE }; } // namespace tokenizers diff --git a/include/pytorch/tokenizers/llama2c_tokenizer.h b/include/pytorch/tokenizers/llama2c_tokenizer.h index 9be6090..91765bd 100644 --- a/include/pytorch/tokenizers/llama2c_tokenizer.h +++ b/include/pytorch/tokenizers/llama2c_tokenizer.h @@ -28,6 +28,15 @@ class Llama2cTokenizer : public Tokenizer { const override; private: + inline Error _decode_verify(uint64_t token) const { + if (!initialized_) { + return Error::Uninitialized; + } + if (token >= vocab_size_) { + return Error::OutOfRange; + } + return Error::Ok; + } std::unique_ptr vocab_ = nullptr; std::unique_ptr vocab_scores_ = nullptr; std::unique_ptr sorted_vocab_ = nullptr; diff --git a/include/pytorch/tokenizers/normalizer.h b/include/pytorch/tokenizers/normalizer.h new file mode 100644 index 0000000..5d0dda5 --- /dev/null +++ b/include/pytorch/tokenizers/normalizer.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +// @lint-ignore-every LICENSELINT + +#pragma once + +// Standard +#include +#include +#include +#include + +// Third Party +#include +#include + +// Local +#include + +namespace tokenizers { + +// -- Base --------------------------------------------------------------------- + +/** + * Base class for all normalizers with a single virtual method to normalize the + * input string + */ +class Normalizer { + public: + /** Shared pointer type */ + typedef std::shared_ptr Ptr; + + /** Normalize the input string + * + * This normalization may result in a string that is different from the + * original input, therefore the resulting string will be owned by the caller. + * + * NOTE: Pass by value per best practice + * https://abseil.io/docs/cpp/guides/strings#string_view + */ + virtual std::string normalize(const std::string& input) const = 0; + + virtual ~Normalizer() = default; +}; // end class Normalizer + +// -- Factory ------------------------------------------------------------------ + +// Helper macro to standardize addition of config member fields +#define NORMALIZER_CONFIG_MEMBER(type, name) \ + std::optional name; \ + NormalizerConfig& set_##name(type arg) { \ + this->name = std::move(arg); \ + return *this; \ + } + +/** + * Factory and config class for creating a new Normalizer + * + * This class is the central method for instantiating a Normalizer instance. + * It contains the common construction logic and config parameter names for all + * normalizer constructor args. + * + * NOTE: When adding a new normalizer, you must ensure its arguments are + * added to this class and it's constructor is added in the implementation! + * + * Usage Example: + * + * const auto normalizer = NormalizerConfig("Replace") + * .set_pattern(" ") + * .set_content("▁") + * .create(); + * const auto normalized = normalizer->normalize("Hello World!"); + */ +class NormalizerConfig { + public: + /*------------------------*/ + /* Public mutable members */ + /*------------------------*/ + + /** + * The Type name string matching from tokenizers + * https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/mod.rs + */ + std::string type; + + /** + * Used by: ReplaceNormalizer + */ + NORMALIZER_CONFIG_MEMBER(std::string, pattern) + + /** + * Used by: ReplaceNormalizer + */ + NORMALIZER_CONFIG_MEMBER(std::string, content) + + /** + * Used by: SequenceNormalizer + */ + NORMALIZER_CONFIG_MEMBER(std::vector, normalizers) + + /*----------------*/ + /* Public methods */ + /*----------------*/ + + /** + * Construct with the type + */ + explicit NormalizerConfig(std::string type = ""); + + /** + * Construct the normalizer instance from the member data + */ + Normalizer::Ptr create() const; + + /** + * Populate from a json config file + */ + NormalizerConfig& parse_json(const nlohmann::json& json_config); + +}; // end class NormalizerConfig + +// -- Replace ------------------------------------------------------------------ +// Used for general-purpose string replacement normalization +// CITE: +// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/replace.rs + +class ReplaceNormalizer : public Normalizer { + public: + /** + * @param pattern: The pattern to search for (can be a string or regex) + * @param content: The replacement content + */ + explicit ReplaceNormalizer( + const std::string& pattern, + const std::string& content) + : regex_(ReplaceNormalizer::create_regex_(pattern)), content_(content) {} + + /** Normalize with the stored pattern replacement */ + std::string 
normalize(const std::string& input) const override; + + protected: + static std::unique_ptr create_regex_(const std::string& pattern); + + std::unique_ptr regex_; + const std::string content_; + +}; // end class ReplaceNormalizer + +// -- Sequence ----------------------------------------------------------------- +// Used by tokenizers +// CITE: +// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/sequence.rs + +class SequenceNormalizer : public Normalizer { + public: + /** + * @param normalizers: The sequence of owned normalizer objects to use + */ + explicit SequenceNormalizer(std::vector normalizers); + + /** Perform normalization */ + std::string normalize(const std::string& input) const override; + + private: + const std::vector normalizers_; + +}; // end class SequenceNormalizer + +} // namespace tokenizers diff --git a/include/pytorch/tokenizers/tokenizer.h b/include/pytorch/tokenizers/tokenizer.h index 1a5f9c3..7fbfb00 100644 --- a/include/pytorch/tokenizers/tokenizer.h +++ b/include/pytorch/tokenizers/tokenizer.h @@ -32,18 +32,19 @@ class Tokenizer { virtual Error load(const std::string& tokenizer_path) = 0; + /** + * Encode the input string into a vector of token IDs. + * + * @param input The input string to tokenize + * @param bos The number of beginning-of-sequence (BOS) tokens to prepend to + * the result + * @param eos The number of end-of-sequence (EOS) tokens to append to the + * result + * @return Result containing a vector of token IDs, or an error if encoding + * fails + */ virtual Result> - encode(const std::string& input, int8_t bos, int8_t eos) const = 0; - - Error decode_verify(uint64_t token) const { - if (!initialized_) { - return Error::Uninitialized; - } - if (token >= vocab_size_) { - return Error::OutOfRange; - } - return Error::Ok; - } + encode(const std::string& input, int8_t bos = 0, int8_t eos = 0) const = 0; virtual Result decode(uint64_t prev_token, uint64_t token) const = 0; diff --git a/src/bpe_tokenizer_base.cpp b/src/bpe_tokenizer_base.cpp index 6627544..002962a 100644 --- a/src/bpe_tokenizer_base.cpp +++ b/src/bpe_tokenizer_base.cpp @@ -23,10 +23,15 @@ static uint64_t _max_size() { return std::numeric_limits::max(); } -static std::vector _byte_pair_merge( +} // namespace + +// ---- Helper utils end ------------------------------------------------------- +// ---- protected start -------------------------------------------------------- + +std::vector BPETokenizerBase::_byte_pair_merge( const std::string& piece, const TokenMap& ranks, - std::function func) { + std::function func) const { // This is a vector of (start, rank). // The rank is of the byte pair starting at position start. // The rank of the last item in the vector is not a valid value. @@ -126,10 +131,6 @@ static std::vector _byte_pair_merge( return out; } -} // namespace -// ---- Helper utils end ------------------------------------------------------- -// ---- protected start -------------------------------------------------------- - std::pair, std::string> BPETokenizerBase::split_with_allowed_special_token_( const std::string& input, @@ -193,12 +194,12 @@ Result> BPETokenizerBase::byte_pair_encode_( if (result) { return std::vector(*result); } else { - // TODO: is it possible? 
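+      // A single-byte piece that is absent from the vocabulary cannot be
+      // split any further, so this is a hard encode failure rather than a
+      // recoverable case.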
TK_LOG(Error, "unknown token: '%s'", piece.c_str()); return Error::EncodeFailure; } } + // Use the original _byte_pair_merge function with the proper merge ranks return _byte_pair_merge( piece, token_map, [&piece, &token_map](uint64_t start, uint64_t stop) { std::string key = piece.substr(start, stop - start); @@ -206,9 +207,8 @@ Result> BPETokenizerBase::byte_pair_encode_( if (result) { return *result; } else { - // TODO: what if key does not exist? Should we - // return `unknown`? assert(false); // ?? - return uint64_t(0); + TK_LOG(Error, "BPE merge produced unknown token: '%s'", key.c_str()); + return uint64_t(0); // Return unknown token ID instead of padding } }); } diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp index fa62264..5ae6c85 100644 --- a/src/hf_tokenizer.cpp +++ b/src/hf_tokenizer.cpp @@ -11,9 +11,9 @@ // Standard #include +#include #include #include -#include #include #include @@ -36,7 +36,7 @@ Error HFTokenizer::load(const std::string& path) { const fs::path root(path); model_json = root / "tokenizer.json"; if (!fs::exists(model_json)) { - fprintf(stderr, "no tokenizer.json found in %s\n", path.c_str()); + TK_LOG(Info, "no tokenizer.json found in %s", path.c_str()); return Error::LoadFailure; } const auto model_config_json_path = root / "tokenizer_config.json"; @@ -48,7 +48,7 @@ Error HFTokenizer::load(const std::string& path) { // Load the tokenizer.json file std::ifstream file(model_json); if (!file) { - fprintf(stderr, "failed to open encoder file: %s\n", path.c_str()); + TK_LOG(Info, "failed to open encoder file: %s", path.c_str()); return Error::LoadFailure; } std::string contents( @@ -57,7 +57,7 @@ Error HFTokenizer::load(const std::string& path) { try { parsed_json = json::parse(contents); } catch (const json::exception& e) { - std::cerr << "Error parsing json file: " << e.what() << std::endl; + TK_LOG(Error, "Error parsing json file: %s", e.what()); return Error::LoadFailure; } @@ -65,7 +65,7 @@ Error HFTokenizer::load(const std::string& path) { try { std::vector> special_token_pairs; const auto& special_tokens = parsed_json.at("added_tokens"); - auto special_token_map = TK_UNWRAP(detail::buildTokenMap( + auto special_token_map = TK_UNWRAP(detail::build_token_map( special_tokens, [](const auto& it) -> std::string { return it.at("content"); }, [](const auto& it) -> std::uint64_t { return it.at("id"); })); @@ -77,7 +77,7 @@ Error HFTokenizer::load(const std::string& path) { // Store for future use. 
special_token_map_.emplace(std::move(special_token_map)); } catch (const json::out_of_range& e) { - fprintf(stderr, "Could not parse special tokens: %s\n", e.what()); + TK_LOG(Info, "Could not parse special tokens: %s", e.what()); return Error::LoadFailure; } @@ -94,25 +94,36 @@ Error HFTokenizer::load(const std::string& path) { } } - auto token_map = TK_UNWRAP(detail::buildTokenMap(std::move(token_pairs))); + auto token_map = TK_UNWRAP(detail::build_token_map(std::move(token_pairs))); token_map_.emplace(std::move(token_map)); } catch (const json::out_of_range& e) { - fprintf(stderr, "Could not parse tokens: %s\n", e.what()); + TK_LOG(Info, "Could not parse tokens: %s", e.what()); return Error::LoadFailure; } // Set the vocab size to include special tokens vocab_size_ = token_map_->size() + special_token_map_->size(); + // Set up the normalizer (optional) + try { + TK_LOG(Info, "Setting up normalizer..."); + _normalizer = + NormalizerConfig().parse_json(parsed_json.at("normalizer")).create(); + TK_LOG(Info, "Normalizer set up"); + } catch (const json::out_of_range& e) { + // No normalizer specified, this is optional + TK_LOG(Info, "No normalizer specified"); + } + // Set up the pre-tokenizer try { - std::cout << "Setting up pretokenizer..." << std::endl; + TK_LOG(Info, "Setting up pretokenizer..."); _pretokenizer = PreTokenizerConfig() .parse_json(parsed_json.at("pre_tokenizer")) .create(); - std::cout << "Pretokenizer set up" << std::endl; + TK_LOG(Info, "Pretokenizer set up"); } catch (const json::out_of_range& e) { - fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what()); + TK_LOG(Info, "Could not parse pre_tokenizer: %s", e.what()); return Error::LoadFailure; } @@ -124,14 +135,67 @@ Error HFTokenizer::load(const std::string& path) { // No decoder specified } - // TODO: Do we need to parse the merges? 
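+  // The "merges" array in tokenizer.json lists pairs of existing tokens
+  // whose concatenation is also expected to be in the vocabulary; array
+  // order defines merge priority (index == rank). A hypothetical excerpt
+  // for illustration only:
+  //
+  //   "model": { ..., "merges": [["h", "e"], ["he", "llo"]] }
+  //
+  // Each pair below is resolved to token ids and stored as
+  // (first_id, second_id) -> (rank, merged_id).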
+ // Parse the BPE merges + try { + TK_LOG(Info, "Loading BPE merges..."); + const auto& merges = parsed_json.at("/model/merges"_json_pointer); + std::vector> merge_pairs; + + for (const auto& merge : merges) { + if (merge.size() == 2) { + std::string first = merge[0]; + std::string second = merge[1]; + merge_pairs.emplace_back(first, second); + } + } + + // Build merge map: (token_id_1, token_id_2) -> (rank, merged_token_id) + merge_map_ = std::make_unique(); + for (size_t i = 0; i < merge_pairs.size(); ++i) { + const auto& [first, second] = merge_pairs[i]; + + // Get token IDs for the merge pair + auto first_id = token_map_->tryGetInteger(first); + auto second_id = token_map_->tryGetInteger(second); + + if (first_id && second_id) { + // Create merged token string + std::string merged = first + second; + auto merged_id = token_map_->tryGetInteger(merged); + + if (merged_id) { + // Store merge rule: (first_id, second_id) -> (rank, merged_id) + merge_map_->emplace( + std::make_pair(*first_id, *second_id), + std::make_pair(static_cast(i), *merged_id)); + } + } + } + + TK_LOG( + Info, + "Loaded %" PRId64 " BPE merge rules", + static_cast(merge_map_->size())); + + // Pre-compute merge ranks for efficient BPE encoding + auto merge_ranks = + TK_UNWRAP(detail::build_merge_ranks_map(*merge_map_, *token_map_)); + TK_LOG( + Info, + "Built merge ranks map with %" PRId64 " entries", + static_cast(merge_ranks.size())); + merge_ranks_.emplace(std::move(merge_ranks)); + } catch (const json::out_of_range& e) { + TK_LOG(Error, "Could not parse merges: %s", e.what()); + return Error::LoadFailure; + } // If a tokenizer config file is found, parse it to look up the eos/bos tokens if (!model_config_json.empty()) { // Load it and parse it as json std::ifstream config_file(model_config_json); if (!config_file) { - fprintf(stderr, "failed to open encoder file: %s\n", path.c_str()); + TK_LOG(Error, "failed to open encoder file: %s", path.c_str()); return Error::LoadFailure; } std::string config_contents( @@ -141,8 +205,7 @@ Error HFTokenizer::load(const std::string& path) { try { parsed_config_json = json::parse(config_contents); } catch (const json::exception& e) { - std::cerr << "Error parsing model config json json file: " << e.what() - << std::endl; + TK_LOG(Error, "Error parsing model config json json file: %s", e.what()); return Error::LoadFailure; } @@ -160,20 +223,17 @@ Error HFTokenizer::load(const std::string& path) { const auto bos_res = special_token_map_->tryGetInteger(bos_token); const auto eos_res = special_token_map_->tryGetInteger(eos_token); if (!bos_res) { - fprintf( - stderr, "BOS token %s not in special tokens\n", bos_token.c_str()); + TK_LOG(Error, "BOS token %s not in special tokens", bos_token.c_str()); return Error::LoadFailure; } if (!eos_res) { - fprintf( - stderr, "EOS token %s not in special tokens\n", eos_token.c_str()); + TK_LOG(Error, "EOS token %s not in special tokens", eos_token.c_str()); return Error::LoadFailure; } bos_tok_ = *bos_res; eos_tok_ = *eos_res; } catch (const json::out_of_range& e) { - fprintf( - stderr, "Could not eos/bos from tokenizer config: %s\n", e.what()); + TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what()); return Error::LoadFailure; } } @@ -249,7 +309,19 @@ Error HFTokenizer::_encode( const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const { - for (const auto& piece : _pretokenizer->pre_tokenize(input)) { + // Apply normalization first if normalizer is available + std::string normalized_input = input; + if 
(_normalizer) { + normalized_input = _normalizer->normalize(input); + TK_LOG( + Info, + "normalized input: '%s' -> '%s'", + input.c_str(), + normalized_input.c_str()); + } + + for (const auto& piece : _pretokenizer->pre_tokenize(normalized_input)) { + // Check if the entire word is already a token to skip merging. const auto result = token_map_->tryGetInteger(piece); if (result) { last_piece_token_len = 1; @@ -272,4 +344,101 @@ void HFTokenizer::_decode(const std::string& input, std::string& ret) const { } } +Result> HFTokenizer::byte_pair_encode_( + const std::string& piece, + const detail::TokenMap& token_map) const { + if (piece.size() == 1) { + const auto result = token_map.tryGetInteger(piece); + if (result) { + return std::vector(*result); + } else { + TK_LOG(Error, "unknown token: '%s'", piece.c_str()); + return Error::EncodeFailure; + } + } + + // Use the pre-computed merge ranks (computed once during loading) + const detail::TokenMap& merge_ranks = + merge_ranks_ ? *merge_ranks_ : token_map; + + // Use the overridden _byte_pair_merge function with the proper merge ranks + return _byte_pair_merge( + piece, merge_ranks, [&piece, &token_map](uint64_t start, uint64_t stop) { + std::string key = piece.substr(start, stop - start); + const auto result = token_map.tryGetInteger(key); + if (result) { + return *result; + } else { + TK_LOG( + Error, + "BPE merge produced unknown token: '%s', start: %" PRIu64 + ", stop: %" PRIu64, + key.c_str(), + start, + stop); + return uint64_t(0); // Return unknown token ID instead of padding + } + }); +} + +std::vector HFTokenizer::_byte_pair_merge( + const std::string& piece, + const detail::TokenMap& ranks, + std::function func) const { + // HF-specific BPE implementation that uses the Rust-style approach + // with pre-computed merge ranks + + // Start with individual characters (like Rust implementation) + HFWord word; + + // Process each UTF-8 character individually + size_t i = 0; + while (i < piece.size()) { + size_t char_start = i; + size_t char_len = 1; + + // Determine UTF-8 character length + unsigned char byte = static_cast(piece[i]); + if ((byte & 0x80) == 0) { + // ASCII character (0xxxxxxx) + char_len = 1; + } else if ((byte & 0xE0) == 0xC0) { + // 2-byte UTF-8 character (110xxxxx) + char_len = 2; + } else if ((byte & 0xF0) == 0xE0) { + // 3-byte UTF-8 character (1110xxxx) + char_len = 3; + } else if ((byte & 0xF8) == 0xF0) { + // 4-byte UTF-8 character (11110xxx) + char_len = 4; + } else { + // Invalid UTF-8 start byte, treat as single byte + char_len = 1; + } + + // Make sure we don't go beyond the string boundary + if (char_start + char_len > piece.size()) { + char_len = piece.size() - char_start; + } + + uint64_t token_id = func(char_start, char_start + char_len); + if (token_id != 0) { // Assuming 0 is padding/error token + word.add(token_id, char_len); + } else { + // Handle unknown character + TK_LOG(Error, "Unknown character in HF BPE at position %zu", char_start); + return {}; // Return empty vector to indicate failure + } + + i += char_len; + } + + // Apply BPE merges using the pre-computed merge ranks and token map + if (merge_ranks_ && token_map_) { + word.merge_all(*merge_ranks_, *token_map_); + } + + return word.tokens; +} + } // namespace tokenizers diff --git a/src/llama2c_tokenizer.cpp b/src/llama2c_tokenizer.cpp index 951ee3d..18464eb 100644 --- a/src/llama2c_tokenizer.cpp +++ b/src/llama2c_tokenizer.cpp @@ -128,7 +128,7 @@ Llama2cTokenizer::~Llama2cTokenizer() { Result Llama2cTokenizer::decode( uint64_t prev_token, 
uint64_t token) const { - TK_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token)); + TK_CHECK_OK_OR_RETURN_ERROR(_decode_verify(token)); const char* piece = vocab_[token]; // following BOS token, sentencepiece decoder strips any leading // whitespace diff --git a/src/normalizer.cpp b/src/normalizer.cpp new file mode 100644 index 0000000..3c9e7f9 --- /dev/null +++ b/src/normalizer.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// @lint-ignore-every LICENSELINT + +// Local +#include + +// Standard +#include +#include +#include + +// Third Party +#include + +using json = nlohmann::json; + +namespace tokenizers { + +// NormalizerConfig //////////////////////////////////////////////////////////// + +NormalizerConfig::NormalizerConfig(std::string type) : type(std::move(type)) {} + +Normalizer::Ptr NormalizerConfig::create() const { + // NOTE: These types must line up with the type strings found in the + // tokenizers library + // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/mod.rs + if (type == "Replace") { + if (!pattern) { + throw std::runtime_error( + "Missing pattern for Normalizer of type Replace"); + } + if (!content) { + throw std::runtime_error( + "Missing content for Normalizer of type Replace"); + } + return Normalizer::Ptr(new ReplaceNormalizer(*pattern, *content)); + } + if (type == "Sequence") { + if (!normalizers or normalizers->empty()) { + throw std::runtime_error( + "Missing normalizers for Normalizer of type Sequence"); + } + std::vector norms; + std::transform( + normalizers->begin(), + normalizers->end(), + std::back_inserter(norms), + [](const NormalizerConfig& cfg) { return cfg.create(); }); + return Normalizer::Ptr(new SequenceNormalizer(norms)); + } + throw std::runtime_error("Unsupported Normalizer type: " + type); +} + +NormalizerConfig& NormalizerConfig::parse_json(const json& json_config) { + type = json_config.at("type"); + if (type == "Replace") { + try { + pattern = json_config.at("pattern").at("Regex"); + } catch (json::out_of_range&) { + // "Regex" is not there, check "String", which is a literal string + std::string literal = json_config.at("pattern").at("String"); + // For string patterns, escape regex special characters to treat them as + // literal strings (same as Rust's regex::escape) + pattern = IRegex::escape(literal); + } + + content = json_config.at("content"); + } else if (type == "Sequence") { + normalizers = std::vector(); + for (const auto& entry : json_config.at("normalizers")) { + normalizers->push_back(NormalizerConfig().parse_json(entry)); + } + } else { + throw std::runtime_error("Unsupported Normalizer type: " + type); + } + return *this; +} + +// ReplaceNormalizer /////////////////////////////////////////////////////////// + +std::unique_ptr ReplaceNormalizer::create_regex_( + const std::string& pattern) { + assert(!pattern.empty()); + return TK_UNWRAP_THROW(create_regex(pattern)); +} + +std::string ReplaceNormalizer::normalize(const std::string& input) const { + if (!regex_) + return input; + + std::string result = input; + auto matches = regex_->find_all(result); + + // Process matches in reverse order to avoid offset issues + for (auto it = matches.rbegin(); it != matches.rend(); ++it) { + const auto& match = *it; + result.replace(match.start, match.end - match.start, content_); + } + + return result; 
+} + +// SequenceNormalizer ////////////////////////////////////////////////////////// + +SequenceNormalizer::SequenceNormalizer(std::vector normalizers) + : normalizers_(std::move(normalizers)) {} + +std::string SequenceNormalizer::normalize(const std::string& input) const { + std::string result = input; + for (const auto& normalizer : normalizers_) { + result = normalizer->normalize(result); + } + return result; +} + +} // namespace tokenizers diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp index c112221..e5f5e90 100644 --- a/src/tiktoken.cpp +++ b/src/tiktoken.cpp @@ -87,7 +87,7 @@ static Result _load_token_map(const std::string& path) { pairs.emplace_back(std::move(token), rank); } - return buildTokenMap(pairs); + return build_token_map(pairs); } } // namespace diff --git a/targets.bzl b/targets.bzl index 2d55359..1f3e963 100644 --- a/targets.bzl +++ b/targets.bzl @@ -136,6 +136,7 @@ def define_common_targets(): "src/hf_tokenizer.cpp", "src/pre_tokenizer.cpp", "src/token_decoder.cpp", + "src/normalizer.cpp", ], deps = [ ":regex", diff --git a/test/resources/test_hf_tokenizer.json b/test/resources/test_hf_tokenizer.json index f9bac15..bc965b0 100644 --- a/test/resources/test_hf_tokenizer.json +++ b/test/resources/test_hf_tokenizer.json @@ -34,10 +34,6 @@ "normalizer": { "type": "Sequence", "normalizers": [ - { - "type": "Prepend", - "prepend": "▁" - }, { "type": "Replace", "pattern": { diff --git a/test/targets.bzl b/test/targets.bzl index f3acaef..cc79100 100644 --- a/test/targets.bzl +++ b/test/targets.bzl @@ -138,6 +138,7 @@ def define_common_targets(): srcs = [ "test_hf_tokenizer.cpp", "test_token_decoder.cpp", + "test_normalizer.cpp", ], deps = [ "//pytorch/tokenizers:hf_tokenizer", diff --git a/test/test_normalizer.cpp b/test/test_normalizer.cpp new file mode 100644 index 0000000..9bf9142 --- /dev/null +++ b/test/test_normalizer.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +// @lint-ignore-every LICENSELINT + +#include +#include + +using namespace tokenizers; + +TEST(NormalizerTest, ReplaceNormalizerBasic) { + // Test basic string replacement + ReplaceNormalizer normalizer(" ", "▁"); + std::string input = "Hello World Test"; + std::string expected = "Hello▁World▁Test"; + std::string result = normalizer.normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, ReplaceNormalizerNoMatch) { + // Test when pattern doesn't match + ReplaceNormalizer normalizer("xyz", "▁"); + std::string input = "Hello World"; + std::string expected = "Hello World"; + std::string result = normalizer.normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, ReplaceNormalizerMultipleMatches) { + // Test multiple matches + ReplaceNormalizer normalizer("a", "X"); + std::string input = "banana"; + std::string expected = "bXnXnX"; + std::string result = normalizer.normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, NormalizerConfigFromJson) { + // Test JSON parsing for Replace normalizer + nlohmann::json config = { + {"type", "Replace"}, {"pattern", {{"String", " "}}}, {"content", "▁"}}; + + NormalizerConfig norm_config; + norm_config.parse_json(config); + auto normalizer = norm_config.create(); + + std::string input = "Hello World Test"; + std::string expected = "Hello▁World▁Test"; + std::string result = normalizer->normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, NormalizerConfigFromJsonRegex) { + // Test JSON parsing for Replace normalizer with regex + nlohmann::json config = { + {"type", "Replace"}, {"pattern", {{"Regex", "\\s+"}}}, {"content", "_"}}; + + NormalizerConfig norm_config; + norm_config.parse_json(config); + auto normalizer = norm_config.create(); + + std::string input = "Hello World\t\tTest"; + std::string expected = "Hello_World_Test"; + std::string result = normalizer->normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, SequenceNormalizer) { + // Test sequence of normalizers + std::vector normalizers; + normalizers.push_back(std::make_shared(" ", "▁")); + normalizers.push_back(std::make_shared("a", "X")); + + SequenceNormalizer seq_normalizer(normalizers); + + std::string input = "banana split"; + std::string expected = "bXnXnX▁split"; + std::string result = seq_normalizer.normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, SequenceNormalizerFromConfig) { + // Test sequence normalizer from config + nlohmann::json config = { + {"type", "Sequence"}, + {"normalizers", + {{{"type", "Replace"}, {"pattern", {{"String", " "}}}, {"content", "▁"}}, + {{"type", "Replace"}, + {"pattern", {{"String", "a"}}}, + {"content", "X"}}}}}; + + NormalizerConfig norm_config; + norm_config.parse_json(config); + auto normalizer = norm_config.create(); + + std::string input = "banana split"; + std::string expected = "bXnXnX▁split"; + std::string result = normalizer->normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, EmptyInput) { + // Test with empty input + ReplaceNormalizer normalizer(" ", "▁"); + std::string input = ""; + std::string expected = ""; + std::string result = normalizer.normalize(input); + EXPECT_EQ(result, expected); +} + +TEST(NormalizerTest, ConfigBuilder) { + // Test config builder pattern + auto normalizer = + NormalizerConfig("Replace").set_pattern(" ").set_content("▁").create(); + + std::string input = "Hello World"; + std::string expected = "Hello▁World"; + std::string result = normalizer->normalize(input); + 
+  EXPECT_EQ(result, expected);
+}
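For context, a minimal usage sketch of the new normalizer-aware encode path (illustrative only, not part of this diff; the tokenizer directory path is a placeholder and the Result accessors ok()/operator* are assumed):

// Sketch: load an HF tokenizer.json and encode with normalization applied.
#include <pytorch/tokenizers/hf_tokenizer.h>

#include <cstdint>
#include <vector>

int main() {
  tokenizers::HFTokenizer tok;
  // Directory containing tokenizer.json (and optionally
  // tokenizer_config.json); placeholder path.
  if (tok.load("/path/to/tokenizer_dir") != tokenizers::Error::Ok) {
    return 1;
  }
  // The normalizer (e.g. Replace " " -> "▁") runs before pre-tokenization,
  // then the explicit merge rules from tokenizer.json drive BPE merging.
  auto ids = tok.encode("Hello World", /*bos=*/1, /*eos=*/0);
  if (!ids.ok()) {
    return 1;
  }
  std::vector<uint64_t> tokens = *ids;
  return 0;
}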