Skip to content

Commit

Permalink
Work with T5's tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Oct 4, 2024
1 parent f87cda2 commit ec435a7
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 11 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ jobs:
huggingface download --silent Qwen/Qwen2-0.5B
yarn tsx src/generate.ts --max-tokens=128 Qwen2-0.5B
huggingface download --silent google-t5/t5-small --hf
huggingface download --silent Xenova/t5-small --hf # fix tokenizer
yarn tsx src/generate.ts --temperature=0 t5-small 'translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast.'
publish:
if: startsWith(github.ref, 'refs/tags/')
needs: [build]
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"typescript": "5.6.2"
},
"dependencies": {
"@frost-beta/mlx": "0.0.22",
"@frost-beta/mlx": "0.0.23",
"@lenml/tokenizers": "1.1.2",
"sharp": "0.33.5"
}
Expand Down
4 changes: 3 additions & 1 deletion src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ export async function* step(promptEmbeds: mx.array,
// The generation is aborted.
if (signal?.aborted)
break;
yield nextToken;
// Do not yield token if it is the decoderStartToken.
if (!(model.hasEncoder && nextToken == model.decoderStartToken))
yield nextToken;
// Forward the token to model and free intermediate tensors.
[ nextToken ] = await mx.tidy(async () => {
const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), memory, cache);
Expand Down
3 changes: 0 additions & 3 deletions src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ main(argv[0], argv[1]);
async function main(dir: string, prompt?: string) {
const llm = await loadLLM(dir);
const promptEmbeds = await llm.encode(prompt);

if (prompt)
process.stdout.write(prompt);
for await (const text of llm.generate(promptEmbeds, options))
process.stdout.write(text);
process.stdout.write('\n');
Expand Down
32 changes: 26 additions & 6 deletions src/tokenizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ export class Tokenizer {
// Create tokenizer.
const tokenizerJSON = readJsonSync(`${dir}/tokenizer.json`);
this.tokenizer = TokenizerLoader.fromPreTrained({tokenizerJSON, tokenizerConfig});
// Remove the strip decoder as it breaks streaming.
const {decoders} = this.tokenizer.decoder as any;
if (decoders?.at(-1)?.config?.type == 'Strip') {
this.trimLeft = true;
decoders.pop();
}
this.removeWhiteSpaceStripper();
// Get EOS token.
const {tokens_to_ids} = this.tokenizer.model;
this.eosToken = tokens_to_ids.get(this.tokenizer.getToken('eos_token'));
Expand Down Expand Up @@ -85,4 +80,29 @@ export class Tokenizer {
return tokens.slice(this.systemPromptLength);
return tokens;
}

// Many tokenizer decoders strip the leading whitespace from the output,
// which breaks our streaming output as whitespace between tokens is lost.
private removeWhiteSpaceStripper() {
const decoder = this.tokenizer.decoder as any;
if (decoder.constructor.name == 'MetaspaceDecoder') {
if (decoder.prepend_scheme == 'always' && decoder.split == true) {
// Work around a bug of transformers.js:
// https://github.com/xenova/transformers.js/issues/959
decoder.addPrefixSpace = true;
delete decoder.prepend_scheme;
delete decoder.split;
}
if (decoder.addPrefixSpace) {
this.trimLeft = true;
decoder.addPrefixSpace = false;
return;
}
}
if (decoder.decoders?.at(-1)?.config?.type == 'Strip') {
this.trimLeft = true;
decoder.decoders.pop();
return;
}
}
}

0 comments on commit ec435a7

Please sign in to comment.