From d1b98cce35c9542cf179a8dd58775db9c0eae00a Mon Sep 17 00:00:00 2001
From: Cheng
Date: Fri, 4 Oct 2024 18:50:21 +0900
Subject: [PATCH] Work with t5's tokenizer

---
 .github/workflows/build.yml |  4 ++++
 src/base.ts                 |  4 +++-
 src/generate.ts             |  3 ---
 src/tokenizer.ts            | 32 ++++++++++++++++++++++++++------
 4 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 775b9c0..091d1a7 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -27,6 +27,10 @@ jobs:
         huggingface download --silent Qwen/Qwen2-0.5B
         yarn tsx src/generate.ts --max-tokens=128 Qwen2-0.5B
 
+        huggingface download --silent google-t5/t5-small --hf
+        huggingface download --silent Xenova/t5-small --hf # fix tokenizer
+        yarn tsx src/generate.ts --temperature=0 t5-small 'translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast.'
+
   publish:
     if: startsWith(github.ref, 'refs/tags/')
     needs: [build]
diff --git a/src/base.ts b/src/base.ts
index b649e41..c3ab4e4 100644
--- a/src/base.ts
+++ b/src/base.ts
@@ -261,7 +261,9 @@ export async function* step(promptEmbeds: mx.array,
     // The generation is aborted.
     if (signal?.aborted)
       break;
-    yield nextToken;
+    // Do not yield the token if it is the decoderStartToken.
+    if (!(model.hasEncoder && nextToken == model.decoderStartToken))
+      yield nextToken;
     // Forward the token to model and free intermediate tensors.
     [ nextToken ] = await mx.tidy(async () => {
       const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), memory, cache);
diff --git a/src/generate.ts b/src/generate.ts
index e7454ba..b432b34 100755
--- a/src/generate.ts
+++ b/src/generate.ts
@@ -14,9 +14,6 @@ main(argv[0], argv[1]);
 async function main(dir: string, prompt?: string) {
   const llm = await loadLLM(dir);
   const promptEmbeds = await llm.encode(prompt);
-
-  if (prompt)
-    process.stdout.write(prompt);
   for await (const text of llm.generate(promptEmbeds, options))
     process.stdout.write(text);
   process.stdout.write('\n');
diff --git a/src/tokenizer.ts b/src/tokenizer.ts
index 45a1fe5..fe1d010 100644
--- a/src/tokenizer.ts
+++ b/src/tokenizer.ts
@@ -43,12 +43,7 @@ export class Tokenizer {
     // Create tokenizer.
     const tokenizerJSON = readJsonSync(`${dir}/tokenizer.json`);
     this.tokenizer = TokenizerLoader.fromPreTrained({tokenizerJSON, tokenizerConfig});
-    // Remove the strip decoder as it breaks streaming.
-    const {decoders} = this.tokenizer.decoder as any;
-    if (decoders?.at(-1)?.config?.type == 'Strip') {
-      this.trimLeft = true;
-      decoders.pop();
-    }
+    this.removeWhiteSpaceStripper();
     // Get EOS token.
     const {tokens_to_ids} = this.tokenizer.model;
     this.eosToken = tokens_to_ids.get(this.tokenizer.getToken('eos_token'));
@@ -85,4 +80,29 @@
       return tokens.slice(this.systemPromptLength);
     return tokens;
   }
+
+  // Many tokenizer decoders strip the leading whitespace from the output,
+  // which breaks our streaming output, as whitespace between tokens is lost.
+  private removeWhiteSpaceStripper() {
+    const decoder = this.tokenizer.decoder as any;
+    if (decoder.constructor.name == 'MetaspaceDecoder') {
+      if (decoder.prepend_scheme == 'always' && decoder.split == true) {
+        // Work around a bug of transformers.js:
+        // https://github.com/xenova/transformers.js/issues/959
+        decoder.addPrefixSpace = true;
+        delete decoder.prepend_scheme;
+        delete decoder.split;
+      }
+      if (decoder.addPrefixSpace) {
+        this.trimLeft = true;
+        decoder.addPrefixSpace = false;
+        return;
+      }
+    }
+    if (decoder.decoders?.at(-1)?.config?.type == 'Strip') {
+      this.trimLeft = true;
+      decoder.decoders.pop();
+      return;
+    }
+  }
 }
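
Note on the src/base.ts change: encoder-decoder models such as T5 prime the
decoder with a dedicated start token, which is a control token rather than
generated text, so step() must not emit it. Below is a minimal sketch of just
the filtering logic; the `Model` interface and `emitTokens` helper are
hypothetical stand-ins, since the real step() is an async generator that also
runs the model and manages the cache.

// Sketch only; `Model` stands in for the real model type in src/base.ts.
interface Model {
  hasEncoder: boolean;
  decoderStartToken: number;
}

// Yield sampled tokens, skipping the synthetic start-of-decoding token
// that encoder-decoder models are seeded with.
function* emitTokens(model: Model, sampled: Iterable<number>): Generator<number> {
  for (const nextToken of sampled) {
    if (model.hasEncoder && nextToken == model.decoderStartToken)
      continue; // a control token, not part of the generated text
    yield nextToken;
  }
}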
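
Note on the src/tokenizer.ts change: T5's SentencePiece tokenizer marks word
boundaries with a "▁" prefix that the Metaspace decoder turns into a leading
space, so a decoder that strips that space from every decode() call drops the
separators between streamed chunks. The sketch below reproduces the failure
and the fix; decodeChunk() is a hypothetical stand-in for the transformers.js
decoder, while the actual patch gets the same effect by disabling
addPrefixSpace (or popping the Strip decoder) and trimming only the first
chunk via trimLeft.

// Sketch only; decodeChunk() stands in for the transformers.js Metaspace
// decoder once addPrefixSpace is disabled, i.e. "▁" always becomes a space.
function decodeChunk(tokens: string[]): string {
  return tokens.map(t => t.replace('▁', ' ')).join('');
}

// Streaming one token at a time: without the patch each chunk would be
// stripped to "Hello" + "world"; with it, only the first chunk is trimmed.
const trimLeft = true;
let first = true;
let out = '';
for (const token of ['▁Hello', '▁world']) {
  let text = decodeChunk([token]);
  if (trimLeft && first) {
    text = text.trimStart();
    first = false;
  }
  out += text;
}
console.log(out); // prints "Hello world"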