
Commit b2945b5

Add token_to_id, id_to_token, and get_added_tokens_decoder methods to Tokenizer class (#13)
1 parent 55bf5e4 commit b2945b5

File tree

7 files changed: +245 -7 lines changed

src/core/PreTokenizer.ts

Lines changed: 8 additions & 2 deletions

@@ -6,15 +6,21 @@ import type { PreTokenizeTextOptions } from "@static/tokenizer";
  * A callable class representing a pre-tokenizer used in tokenization. Subclasses
  * should implement the `pre_tokenize_text` method to define the specific pre-tokenization logic.
  */
-abstract class PreTokenizer extends Callable<[string | string[], any?], string[]> {
+abstract class PreTokenizer extends Callable<
+  [string | string[], any?],
+  string[]
+> {
   /**
    * Method that should be implemented by subclasses to define the specific pre-tokenization logic.
    *
    * @param text The text to pre-tokenize.
    * @param options Additional options for the pre-tokenization logic.
    * @returns The pre-tokenized text.
    */
-  abstract pre_tokenize_text(text: string, options?: PreTokenizeTextOptions): string[];
+  abstract pre_tokenize_text(
+    text: string,
+    options?: PreTokenizeTextOptions,
+  ): string[];

   /**
    * Tokenizes the given text into pre-tokens.

src/core/Tokenizer.ts

Lines changed: 35 additions & 1 deletion

@@ -19,7 +19,11 @@ import type PreTokenizer from "./PreTokenizer";
 import type TokenizerModel from "./TokenizerModel";
 import type PostProcessor from "./PostProcessor";
 import type Decoder from "./Decoder";
-import type { TokenConfig, TokenizerConfig, TokenizerJSON } from "@static/tokenizer";
+import type {
+  TokenConfig,
+  TokenizerConfig,
+  TokenizerJSON,
+} from "@static/tokenizer";

 interface EncodeOptions {
   text_pair?: string | null;

@@ -292,6 +296,36 @@ class Tokenizer {
       ? this.post_processor(tokens1, tokens2, add_special_tokens)
       : { tokens: merge_arrays(tokens1 ?? [], tokens2 ?? []) };
   }
+
+  /**
+   * Converts a token string to its corresponding token ID.
+   * @param token The token string to convert.
+   * @returns The token ID, or undefined if the token is not in the vocabulary.
+   */
+  public token_to_id(token: string): number | undefined {
+    return this.model.tokens_to_ids.get(token);
+  }
+
+  /**
+   * Converts a token ID to its corresponding token string.
+   * @param id The token ID to convert.
+   * @returns The token string, or undefined if the ID is not in the vocabulary.
+   */
+  public id_to_token(id: number): string | undefined {
+    return this.model.vocab[id];
+  }
+
+  /**
+   * Returns a mapping of token IDs to AddedToken objects for all added tokens.
+   * @returns A Map where keys are token IDs and values are AddedToken objects.
+   */
+  public get_added_tokens_decoder(): Map<number, AddedToken> {
+    const decoder = new Map<number, AddedToken>();
+    for (const token of this.added_tokens) {
+      decoder.set(token.id, token);
+    }
+    return decoder;
+  }
 }

 export default Tokenizer;
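
For reference, a minimal usage sketch of the new lookup methods, assuming the tokenizerJson and tokenizerConfig fixtures defined in tests/tokenizers.test.ts (below) are in scope; the expected values follow that fixture's vocabulary.

  import { Tokenizer, AddedToken } from "../src";

  const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

  // String-to-ID and ID-to-string lookups; both return undefined on a miss.
  tokenizer.token_to_id("ab");  // 7
  tokenizer.id_to_token(7);     // "ab"
  tokenizer.token_to_id("xyz"); // undefined

  // Added/special tokens, keyed by token ID.
  const added: Map<number, AddedToken> = tokenizer.get_added_tokens_decoder();
  added.get(0)?.content; // "<unk>"
  added.get(0)?.special; // true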

src/core/decoder/create_decoder.ts

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-
 import ByteLevel from "./ByteLevel";
 import WordPiece from "./WordPiece";
 import Metaspace from "./Metaspace";

src/core/normalizer/create_normalizer.ts

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-
 import BertNormalizer from "./BertNormalizer";
 import Precompiled from "./Precompiled";
 import Sequence from "./Sequence";

src/core/postProcessor/create_post_processor.ts

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-
 import TemplateProcessing from "./TemplateProcessing";
 import ByteLevel from "./ByteLevel";
 import BertProcessing from "./BertProcessing";

src/index.ts

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 export { default as Tokenizer } from "./core/Tokenizer";
+export { default as AddedToken } from "./core/AddedToken";
 export type { Encoding } from "./static/types";

 // Export all components

tests/tokenizers.test.ts

Lines changed: 201 additions & 1 deletion

@@ -1,5 +1,5 @@
 import fetchConfigById from "./utils/fetchConfigById";
-import { Tokenizer } from "../src";
+import { Tokenizer, AddedToken } from "../src";
 import collectTests from "./utils/collectTests";

 const TOKENIZER_TESTS = await collectTests();

@@ -43,3 +43,203 @@ describe("Tokenizers (model-specific)", () => {
     });
   }
 });
+
+describe("Tokenizer methods", () => {
+  // Create a simple BPE tokenizer for testing
+  // Vocab size: 10 tokens
+  // - 3 special tokens: <s>, </s>, <pad>
+  // - 1 unk token: <unk>
+  // - 5 regular tokens: a, b, c, ab, bc
+  // - 1 non-special added token: "<added>"
+  const unk_token = "<unk>";
+  const bos_token = "<s>";
+  const eos_token = "</s>";
+  const pad_token = "<pad>";
+  const added_token = "<added>";
+
+  const added_tokens = [
+    new AddedToken({
+      id: 0,
+      content: unk_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 1,
+      content: bos_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 2,
+      content: eos_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 3,
+      content: pad_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 9,
+      content: added_token,
+      special: false, // regular added token
+    }),
+  ];
+
+  const tokenizerJson = {
+    version: "1.0",
+    truncation: null,
+    padding: null,
+    added_tokens,
+    normalizer: null,
+    pre_tokenizer: null,
+    post_processor: null,
+    decoder: null,
+    model: {
+      type: "BPE",
+      dropout: null,
+      unk_token,
+      continuing_subword_prefix: null,
+      end_of_word_suffix: null,
+      fuse_unk: false,
+      byte_fallback: false,
+      ignore_merges: false,
+      vocab: {
+        [unk_token]: 0,
+        [bos_token]: 1,
+        [eos_token]: 2,
+        [pad_token]: 3,
+        a: 4,
+        b: 5,
+        c: 6,
+        ab: 7,
+        bc: 8,
+      },
+      merges: [
+        ["a", "b"],
+        ["b", "c"],
+      ],
+    },
+  } as any;
+
+  const tokenizerConfig = {
+    add_bos_token: false,
+    add_prefix_space: false,
+    added_tokens_decoder: Object.fromEntries(added_tokens.map((token) => [String(token.id), { id: token.id, content: token.content, special: token.special }])),
+    bos_token,
+    clean_up_tokenization_spaces: false,
+    eos_token,
+    legacy: true,
+    model_max_length: 1000000000000000,
+    pad_token,
+    sp_model_kwargs: {},
+    spaces_between_special_tokens: false,
+    tokenizer_class: "LlamaTokenizer",
+    unk_token,
+  };
+
+  let tokenizer: Tokenizer;
+
+  beforeAll(() => {
+    tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
+  });
+
+  describe("token_to_id", () => {
+    test("should return correct ID for regular token", () => {
+      expect(tokenizer.token_to_id("a")).toBe(4);
+      expect(tokenizer.token_to_id("b")).toBe(5);
+      expect(tokenizer.token_to_id("c")).toBe(6);
+    });
+
+    test("should return correct ID for merged token", () => {
+      expect(tokenizer.token_to_id("ab")).toBe(7);
+      expect(tokenizer.token_to_id("bc")).toBe(8);
+    });
+
+    test("should return correct ID for special tokens", () => {
+      expect(tokenizer.token_to_id(unk_token)).toBe(0);
+      expect(tokenizer.token_to_id(bos_token)).toBe(1);
+      expect(tokenizer.token_to_id(eos_token)).toBe(2);
+      expect(tokenizer.token_to_id(pad_token)).toBe(3);
+      expect(tokenizer.token_to_id(added_token)).toBe(9);
+    });
+
+    test("should return undefined for non-existing token", () => {
+      expect(tokenizer.token_to_id("xyz")).toBeUndefined();
+    });
+  });
+
+  describe("id_to_token", () => {
+    test("should return correct token for regular token ID", () => {
+      expect(tokenizer.id_to_token(4)).toBe("a");
+      expect(tokenizer.id_to_token(5)).toBe("b");
+      expect(tokenizer.id_to_token(6)).toBe("c");
+    });
+
+    test("should return correct token for merged token ID", () => {
+      expect(tokenizer.id_to_token(7)).toBe("ab");
+      expect(tokenizer.id_to_token(8)).toBe("bc");
+    });
+
+    test("should return correct token for special/added token ID", () => {
+      expect(tokenizer.id_to_token(0)).toBe(unk_token);
+      expect(tokenizer.id_to_token(1)).toBe(bos_token);
+      expect(tokenizer.id_to_token(2)).toBe(eos_token);
+      expect(tokenizer.id_to_token(3)).toBe(pad_token);
+      expect(tokenizer.id_to_token(9)).toBe(added_token);
+    });
+
+    test("should return undefined for non-existing ID", () => {
+      expect(tokenizer.id_to_token(999)).toBeUndefined();
+    });
+  });
+
+  describe("get_added_tokens_decoder", () => {
+    test("should return a Map", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder).toBeInstanceOf(Map);
+    });
+
+    test("should contain all special tokens", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder.size).toBe(5);
+      expect(decoder.has(0)).toBe(true);
+      expect(decoder.has(1)).toBe(true);
+      expect(decoder.has(2)).toBe(true);
+      expect(decoder.has(3)).toBe(true);
+      expect(decoder.has(9)).toBe(true);
+    });
+
+    test("should return AddedToken objects with correct properties", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      const unkToken = decoder.get(0);
+      expect(unkToken).toBeDefined();
+      expect(unkToken?.content).toBe(unk_token);
+      expect(unkToken?.special).toBe(true);
+      expect(unkToken).toBeInstanceOf(AddedToken);
+
+      const bosToken = decoder.get(1);
+      expect(bosToken?.content).toBe(bos_token);
+      expect(bosToken?.special).toBe(true);
+    });
+
+    test("should not contain regular tokens", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder.has(4)).toBe(false);
+      expect(decoder.has(5)).toBe(false);
+      expect(decoder.has(6)).toBe(false);
+    });
+  });
+
+  describe("roundtrip conversions", () => {
+    test("token_to_id and id_to_token should be inverse operations", () => {
+      const tokens = [unk_token, bos_token, eos_token, pad_token, "a", "b", "c", "ab", "bc", added_token];
+
+      for (const token of tokens) {
+        const id = tokenizer.token_to_id(token);
+        expect(id).toBeDefined();
+        const tokenBack = tokenizer.id_to_token(id!);
+        expect(tokenBack).toBe(token);
+      }
+    });
+  });
+});

0 commit comments
