
Commit 2d76b21

Fix tokenizer encoding problem for multibyte strings
Parent: ea95a1d, commit 2d76b21


2 files changed: +61, -94 lines


src/llama-vocab.cpp

Lines changed: 60 additions & 87 deletions
@@ -3503,8 +3503,14 @@ void llama_vocab_plamo2::build(const std::vector<vocab_entry> & vocab) {
 
         // Add token and all its suffixes to suffix_to_score
         suffix_to_score[entry.text] = entry.score;
-        for (size_t i = 1; i < entry.text.length(); ++i) {
-            std::string suffix = entry.text.substr(i);
+
+        // Extract suffixes character by character (UTF-8 aware)
+        std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
+        for (size_t i = 1; i < cpts.size(); ++i) {
+            std::string suffix;
+            for (size_t j = i; j < cpts.size(); ++j) {
+                suffix += unicode_cpt_to_utf8(cpts[j]);
+            }
             if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
                 suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
             }
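The hunk above replaces byte-indexed substr() with code-point-indexed suffixes. The standalone sketch below (not part of the commit; its splitter is a simplified stand-in for the unicode_cpts_from_utf8 / unicode_cpt_to_utf8 helpers called in the diff) shows why this matters: a suffix taken at a byte offset can start in the middle of a multibyte character, while a suffix taken at a code-point boundary cannot.

    // Illustrative only: why suffixes must be taken per code point, not per byte.
    #include <cstdio>
    #include <string>
    #include <vector>

    // Split a UTF-8 string at code-point boundaries (no validation; sketch only).
    static std::vector<std::string> split_utf8(const std::string & s) {
        std::vector<std::string> out;
        for (size_t i = 0; i < s.size(); ) {
            unsigned char c = s[i];
            size_t len = (c < 0x80) ? 1 : (c < 0xE0) ? 2 : (c < 0xF0) ? 3 : 4;
            out.push_back(s.substr(i, len));
            i += len;
        }
        return out;
    }

    int main() {
        const std::string text = "こんにちは"; // 5 characters, 15 bytes in UTF-8

        // Byte-based suffix (old behavior): starts in the middle of a character.
        std::string byte_suffix = text.substr(1);          // invalid UTF-8 fragment

        // Code-point-based suffix (new behavior): drop the first character whole.
        std::vector<std::string> chars = split_utf8(text);
        std::string cpt_suffix;
        for (size_t i = 1; i < chars.size(); ++i) {
            cpt_suffix += chars[i];
        }

        std::printf("byte suffix length: %zu bytes\n", byte_suffix.size()); // 14
        std::printf("code-point suffix:  %s\n", cpt_suffix.c_str());        // んにちは
        return 0;
    }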
@@ -3535,26 +3541,34 @@ void llama_vocab_plamo2::build(const std::vector<vocab_entry> & vocab) {
     std::unordered_map<std::string, int32_t> suffix_to_id;
     int32_t num_pieces = 0;
 
-    for (const auto & s : suffixes) {
-        suffix_to_id[s] = num_pieces;
-        if (!s.empty()) {
-            // Convert first character to Unicode code point
-            std::vector<int32_t> unicode_chars = utf8_to_unicode(s);
-            if (!unicode_chars.empty()) {
-                int64_t piece_code = (static_cast<int64_t>(unicode_chars[0]) << 32) | suffix_to_id[s.substr(1)];
-                to_suffix_id_[piece_code] = num_pieces;
+    for (const auto & suffix : suffixes) {
+        suffix_to_id[suffix] = num_pieces;
+        if (!suffix.empty()) {
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+
+            std::string remaining;
+            for (size_t i = 1; i < cpts.size(); ++i) {
+                remaining += unicode_cpt_to_utf8(cpts[i]);
             }
-        }
 
-        // Count number of pieces for this suffix
-        int32_t pieces_for_suffix = 1; // sentinel row
-        for (size_t i = 1; i <= s.length(); ++i) {
-            std::string prefix = s.substr(0, i);
-            if (suffix_to_score.find(prefix) != suffix_to_score.end()) {
-                pieces_for_suffix++;
+            int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
+            to_suffix_id_[piece_code] = num_pieces;
+
+            // Count number of pieces for this suffix
+            int32_t pieces_for_suffix = 1; // sentinel row
+            for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                std::string piece;
+                for (int32_t i = 0; i < piece_length; ++i) {
+                    piece += unicode_cpt_to_utf8(cpts[i]);
+                }
+                if (suffix_to_score.find(piece) != suffix_to_score.end()) {
+                    pieces_for_suffix++;
+                }
             }
+            num_pieces += pieces_for_suffix;
+        } else {
+            num_pieces++; // Empty suffix contributes one piece (sentinel row)
         }
-        num_pieces += pieces_for_suffix;
     }
 
     // Build flattened table
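The rewritten loop above keys to_suffix_id_ by a piece_code that packs the first code point of a suffix into the upper 32 bits and the ID of the remaining suffix into the lower 32 bits. A minimal sketch of that packing, with made-up values and a helper name of my own:

    // Illustrative only: packing (first code point, remaining-suffix ID) into one
    // 64-bit key, as piece_code does in the diff above.
    #include <cassert>
    #include <cstdint>

    static int64_t make_piece_code(uint32_t first_cpt, int32_t suffix_id) {
        // suffix IDs are non-negative, so the cast keeps the low 32 bits intact
        return (static_cast<int64_t>(first_cpt) << 32) | static_cast<uint32_t>(suffix_id);
    }

    int main() {
        const uint32_t cpt = 0x3053; // 'こ'
        const int32_t  id  = 42;     // ID of the remaining suffix (made-up value)

        int64_t code = make_piece_code(cpt, id);

        // Both fields can be recovered independently from the key.
        assert(static_cast<uint32_t>(code >> 32) == cpt);
        assert(static_cast<int32_t>(code & 0xFFFFFFFF) == id);
        return 0;
    }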
@@ -3563,8 +3577,13 @@ void llama_vocab_plamo2::build(const std::vector<vocab_entry> & vocab) {
 
     for (const auto & suffix : suffixes) {
         // Add all prefixes of the suffix to the table (in decreasing order of length)
-        for (int32_t piece_length = static_cast<int32_t>(suffix.length()); piece_length > 0; --piece_length) {
-            std::string piece = suffix.substr(0, piece_length);
+        std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+        for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+            std::string piece;
+            for (int32_t i = 0; i < piece_length; ++i) {
+                piece += unicode_cpt_to_utf8(cpts[i]);
+            }
+
             auto score_it = suffix_to_score.find(piece);
             if (score_it == suffix_to_score.end()) {
                 continue;
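The table-building loop above now enumerates the prefixes of each suffix from longest to shortest, cutting only at code-point boundaries. A toy illustration with a pre-split suffix (values chosen only for the example):

    // Illustrative only: longest-to-shortest prefix enumeration over code points.
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // The suffix "にちは", already split into its code points.
        const std::vector<std::string> chars = { "に", "ち", "は" };

        for (int piece_length = static_cast<int>(chars.size()); piece_length > 0; --piece_length) {
            std::string piece;
            for (int i = 0; i < piece_length; ++i) {
                piece += chars[i];
            }
            std::printf("%s\n", piece.c_str()); // prints "にちは", then "にち", then "に"
        }
        return 0;
    }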
@@ -3590,51 +3609,7 @@ void llama_vocab_plamo2::build(const std::vector<vocab_entry> & vocab) {
     }
 }
 
-std::vector<int32_t> llama_vocab_plamo2::utf8_to_unicode(const std::string & text) const {
-    std::vector<int32_t> result;
-    const char * ptr = text.c_str();
-    const char * end = ptr + text.length();
-
-    while (ptr < end) {
-        int32_t codepoint = 0;
-        int bytes_read = 0;
-
-        if ((*ptr & 0x80) == 0) {
-            // ASCII
-            codepoint = *ptr;
-            bytes_read = 1;
-        } else if ((*ptr & 0xE0) == 0xC0) {
-            // 2-byte UTF-8
-            codepoint = (*ptr & 0x1F) << 6;
-            codepoint |= (*(ptr + 1) & 0x3F);
-            bytes_read = 2;
-        } else if ((*ptr & 0xF0) == 0xE0) {
-            // 3-byte UTF-8
-            codepoint = (*ptr & 0x0F) << 12;
-            codepoint |= (*(ptr + 1) & 0x3F) << 6;
-            codepoint |= (*(ptr + 2) & 0x3F);
-            bytes_read = 3;
-        } else if ((*ptr & 0xF8) == 0xF0) {
-            // 4-byte UTF-8
-            codepoint = (*ptr & 0x07) << 18;
-            codepoint |= (*(ptr + 1) & 0x3F) << 12;
-            codepoint |= (*(ptr + 2) & 0x3F) << 6;
-            codepoint |= (*(ptr + 3) & 0x3F);
-            bytes_read = 4;
-        } else {
-            // Invalid UTF-8, skip this byte
-            ptr++;
-            continue;
-        }
-
-        result.push_back(codepoint);
-        ptr += bytes_read;
-    }
-
-    return result;
-}
-
-std::vector<llama_token> llama_vocab_plamo2::encode_unicode(const std::vector<int32_t> & unicode_data) const {
+std::vector<llama_token> llama_vocab_plamo2::encode_unicode(const std::vector<uint32_t> & unicode_data) const {
     if (unicode_data.empty()) {
         return {};
     }
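The hand-rolled utf8_to_unicode() decoder above is deleted in favour of the shared unicode_cpts_from_utf8() helper, and code points move from int32_t to uint32_t. The standalone sketch below shows the kind of decoding the helper is expected to perform (simplified, no bounds or validity checking); it is not the helper's actual implementation:

    // Illustrative only: decoding a UTF-8 string into Unicode code points.
    #include <cassert>
    #include <cstdint>
    #include <string>
    #include <vector>

    static std::vector<uint32_t> decode_utf8(const std::string & s) {
        std::vector<uint32_t> cpts;
        for (size_t i = 0; i < s.size(); ) {
            unsigned char c = s[i];
            if (c < 0x80) {            // 1-byte sequence (ASCII)
                cpts.push_back(c);
                i += 1;
            } else if (c < 0xE0) {     // 2-byte sequence
                cpts.push_back(((c & 0x1F) << 6) | (s[i + 1] & 0x3F));
                i += 2;
            } else if (c < 0xF0) {     // 3-byte sequence
                cpts.push_back(((c & 0x0F) << 12) | ((s[i + 1] & 0x3F) << 6) | (s[i + 2] & 0x3F));
                i += 3;
            } else {                   // 4-byte sequence
                cpts.push_back(((c & 0x07) << 18) | ((s[i + 1] & 0x3F) << 12) |
                               ((s[i + 2] & 0x3F) << 6) | (s[i + 3] & 0x3F));
                i += 4;
            }
        }
        return cpts;
    }

    int main() {
        std::vector<uint32_t> cpts = decode_utf8("aこ😀");
        assert(cpts.size() == 3);
        assert(cpts[0] == 0x61 && cpts[1] == 0x3053 && cpts[2] == 0x1F600);
        return 0;
    }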
@@ -3652,7 +3627,7 @@ std::vector<llama_token> llama_vocab_plamo2::encode_unicode(const std::vector<in
 
     // Process from end to beginning
     for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
-        int32_t c = unicode_data[i];
+        uint32_t c = unicode_data[i];
 
         // Find next suffix ID
         for (size_t p = suffix_id; p < table_.size(); ++p) {
@@ -3701,40 +3676,38 @@ std::vector<llama_token> llama_vocab_plamo2::encode_unicode(const std::vector<in
             token_ids.push_back(path[pos][PATH_TOKEN_ID]);
         } else {
             // Fall back to byte tokens
-            int32_t c = unicode_data[pos];
+            uint32_t c = unicode_data[pos];
             int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
 
-            for (int j = 0; j < s; ++j) {
-                uint8_t b = (s == 1) ? c :
-                            (j == 0) ? (0xF00 >> s) & 0xFF :
-                            0x80 | ((c >> ((s - j - 1) * 6)) & 0x3F);
-                token_ids.push_back(bytes_[b]);
+            for (int i = 0; i < s; ++i) {
+                uint8_t b;
+                if (s == 1) {
+                    b = c;
+                } else {
+                    if (i == 0) {
+                        b = (0xF00 >> s) & 0xFF;
+                    } else {
+                        b = 0x80;
+                    }
+                }
+                token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
             }
         }
 
+        assert(path[pos][PATH_TOKEN_LENGTH] > 0);
         pos += path[pos][PATH_TOKEN_LENGTH];
     }
 
     return token_ids;
 }
 
 std::vector<llama_token> llama_vocab_plamo2::encode(const std::string & text) const {
-    std::vector<int32_t> unicode_data = utf8_to_unicode(text);
-    return encode_unicode(unicode_data);
-}
-
-std::vector<std::string> llama_vocab_plamo2::encode_as_tokens(const std::string & text) const {
-    std::vector<llama_token> token_ids = encode(text);
-    std::vector<std::string> result;
-    result.reserve(token_ids.size());
-
-    for (llama_token id : token_ids) {
-        if (id >= 0 && id < static_cast<llama_token>(tokens_.size())) {
-            result.push_back(tokens_[id]);
-        }
+    std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+    // Skip the first code point if it is a BOM (Byte Order Mark)
+    if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+        unicode_data.erase(unicode_data.begin());
     }
-
-    return result;
+    return encode_unicode(unicode_data);
 }
 
 const std::string & llama_vocab_plamo2::get_token_text(llama_token id) const {
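The byte-fallback branch above re-encodes an unmatched code point into its UTF-8 bytes and looks each byte up in bytes_, and encode() now also strips a leading U+FEFF BOM. The standalone sketch below checks that the same bit manipulation reproduces standard UTF-8 byte sequences; the encode_utf8 helper is mine, written only for illustration:

    // Illustrative only: re-encoding a code point into UTF-8 bytes with the same
    // length computation and prefix trick used by the fallback loop in the diff.
    #include <cassert>
    #include <cstdint>
    #include <string>

    static std::string encode_utf8(uint32_t c) {
        std::string out;
        int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); // sequence length in bytes
        for (int i = 0; i < s; ++i) {
            uint8_t b;
            if (s == 1) {
                b = c;                   // ASCII: the byte is the code point itself
            } else if (i == 0) {
                b = (0xF00 >> s) & 0xFF; // lead-byte prefix: 0xC0 / 0xE0 / 0xF0
            } else {
                b = 0x80;                // continuation-byte prefix
            }
            out += static_cast<char>(b | ((c >> ((s - i - 1) * 6)) & 0x3F));
        }
        return out;
    }

    int main() {
        assert(encode_utf8(0x61)    == "a");
        assert(encode_utf8(0x3053)  == "\xE3\x81\x93");     // こ
        assert(encode_utf8(0x1F600) == "\xF0\x9F\x98\x80"); // 😀
        return 0;
    }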

src/llama-vocab.h

Lines changed: 1 addition & 7 deletions
@@ -165,9 +165,6 @@ class llama_vocab_plamo2 {
     // Encode text to token IDs
     std::vector<llama_token> encode(const std::string & text) const;
 
-    // Encode text to token strings
-    std::vector<std::string> encode_as_tokens(const std::string & text) const;
-
     // Get token text by ID
     const std::string & get_token_text(llama_token id) const;
 
@@ -191,8 +188,5 @@ class llama_vocab_plamo2 {
     std::vector<std::vector<int32_t>> table_;
 
     // Helper functions
-    void build_suffix_map(const std::vector<vocab_entry> & vocab);
-    void build_trie_table(const std::vector<vocab_entry> & vocab);
-    std::vector<int32_t> utf8_to_unicode(const std::string & text) const;
-    std::vector<llama_token> encode_unicode(const std::vector<int32_t> & unicode_data) const;
+    std::vector<llama_token> encode_unicode(const std::vector<uint32_t> & unicode_data) const;
 };
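With encode_as_tokens() removed from the header, callers can still recover token strings from the two members that remain, encode() and get_token_text(). A hypothetical usage sketch, not part of the commit, assuming the repository's llama-vocab.h (which is expected to pull in the llama_token type) and a vocab object built elsewhere:

    // Illustrative only: rebuilding the old encode_as_tokens() behaviour.
    #include <string>
    #include <vector>

    #include "llama-vocab.h"

    std::vector<std::string> tokens_as_strings(const llama_vocab_plamo2 & vocab,
                                               const std::string & text) {
        std::vector<std::string> out;
        for (llama_token id : vocab.encode(text)) {
            out.push_back(vocab.get_token_text(id));
        }
        return out;
    }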
