diff --git a/README.md b/README.md
index 462654d..d16a9c2 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ export class LLM {
   /**
    * Predict next tokens using the embeddings of prompt.
    */
-  async *generate(promptEmbeds: mx.array, options?: LLMGenerateOptions): AsyncGenerator<string>;
+  async *generate(promptEmbeds: mx.array, options?: LLMGenerateOptions): AsyncGenerator<string[]>;
 }
 
 /**
diff --git a/src/base.ts b/src/base.ts
index 569f2c0..42de5c3 100644
--- a/src/base.ts
+++ b/src/base.ts
@@ -210,7 +210,9 @@ export async function* step(promptEmbeds: mx.array,
                               kvCache,
                               topP = 0.8,
                               temperature = 1,
-                            }: StepOptions = {}): AsyncGenerator<number> {
+                            }: StepOptions = {}): AsyncGenerator<number[]> {
+  const [ batchSize, embeddingsSize ] = promptEmbeds.shape;
+
   // Create KV Cache if none is specified in options.
   const cache = kvCache ?? RotatingKVCache.create(model.getDecoderKVCacheOptions());
 
@@ -219,22 +221,21 @@
     logits = logits.index(mx.Slice(), -1);
     const [ token ] = sample(logits, topP, temperature);
     await mx.asyncEval(token);
-    return token.item() as number;
+    return token.tolist() as number[];
   };
 
   // Handle prompt: for encoder-decoder models we pass it to encoder, for
   // decoder-only models we pass it to decoder directly.
-  let nextToken: number;
+  let nextTokens: number[];
   let memory: mx.array | undefined;
   if (model.hasEncoder) {
-    nextToken = model.decoderStartToken;
+    nextTokens = new Array(batchSize).fill(model.decoderStartToken);
     memory = model.encodeEmbeddings(promptEmbeds);
     mx.metal.clearCache();
   } else {
     // Forward prompt by steps so we don't use too much RAM.
     // See also https://github.com/ml-explore/mlx-examples/pull/931
     const prefillStepSize = 512;
-    const embeddingsSize = promptEmbeds.shape[1];
     for (let offset = 0; offset < embeddingsSize;) {
       if (signal?.aborted)
         break;
@@ -246,7 +247,7 @@
       offset += size;
       // Do token-by-token generation after prompt is consumed.
       if (offset == embeddingsSize)
-        nextToken = await predict(logits);
+        nextTokens = await predict(logits);
       // Keep the cache from being released.
       return cache;
     });
@@ -256,17 +257,17 @@
 
   do {
     // Quit after getting EOS.
-    if (nextToken == eosToken)
+    if (nextTokens.indexOf(eosToken) > -1)
       break;
     // The generation is aborted.
     if (signal?.aborted)
       break;
     // Do not yield token if it is the decoderStartToken.
-    if (!(model.hasEncoder && nextToken == model.decoderStartToken))
-      yield nextToken;
+    if (!(model.hasEncoder && nextTokens.indexOf(model.decoderStartToken) > -1))
+      yield nextTokens;
     // Forward the token to model and free intermediate tensors.
-    [ nextToken ] = await mx.tidy(async () => {
-      const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), memory, cache);
+    [ nextTokens ] = await mx.tidy(async () => {
+      const logits = model.forward(mx.array(nextTokens, mx.int32).index(mx.Slice(), mx.newaxis), memory, cache);
       // The cache is also returned so it does not get freed by mx.tidy().
       return [ await predict(logits), cache ];
     });
diff --git a/src/chat.ts b/src/chat.ts
index 26d7b9e..dc14930 100755
--- a/src/chat.ts
+++ b/src/chat.ts
@@ -66,7 +66,7 @@ async function talk(rl: readline.Interface,
 
   // Predict next tokens.
   let result = '';
-  for await (const text of llm.generate(promptEmbeds, options)) {
+  for await (const [ text ] of llm.generate(promptEmbeds, options)) {
     result += text;
     process.stdout.write(text);
   }
diff --git a/src/generate.ts b/src/generate.ts
index b432b34..c6b0fb4 100755
--- a/src/generate.ts
+++ b/src/generate.ts
@@ -14,7 +14,7 @@ main(argv[0], argv[1]);
 
 async function main(dir: string, prompt?: string) {
   const llm = await loadLLM(dir);
   const promptEmbeds = await llm.encode(prompt);
-  for await (const text of llm.generate(promptEmbeds, options))
+  for await (const [ text ] of llm.generate(promptEmbeds, options))
     process.stdout.write(text);
   process.stdout.write('\n');
 }
diff --git a/src/llm.ts b/src/llm.ts
index 39c6a7e..1279b8f 100644
--- a/src/llm.ts
+++ b/src/llm.ts
@@ -99,6 +99,7 @@ export class LLM {
    * Predict next tokens using the embeddings of prompt.
    */
   async *generate(promptEmbeds: mx.array, options: LLMGenerateOptions = {}) {
+    const [ batchSize ] = promptEmbeds.shape;
     // If not specified, create a shared cache between generations.
     if (!options.kvCache) {
       if (!this.kvCache) {
@@ -111,22 +112,26 @@
       options.kvCache = this.kvCache;
     }
     // Predict next tokens.
-    let buffer: number[] = [];
+    let buffers: number[][] = Array.from({length: batchSize}, () => []);
     let count = 0;
-    for await (const token of step(promptEmbeds, this.model, this.tokenizer.eosToken, options)) {
+    for await (const tokens of step(promptEmbeds, this.model, this.tokenizer.eosToken, options)) {
       ++count;
       if (options.maxTokens && count > options.maxTokens)
         break;
-      buffer.push(token);
-      let text = this.tokenizer.decode(buffer);
-      // The token may represent an incomplete unicode char.
-      if (text.endsWith('\u{FFFD}'))
-        continue;
-      // Trim left whitespace for the first output.
-      if (this.tokenizer.trimLeft && count == 1)
-        text = text.trimLeft();
-      yield text;
-      buffer = [];
+      const results: string[] = Array.from({length: batchSize}, () => '');
+      for (let i = 0; i < batchSize; ++ i) {
+        buffers[i].push(tokens[i]);
+        let text = this.tokenizer.decode(buffers[i]);
+        // The token may represent an incomplete unicode char.
+        if (text.endsWith('\u{FFFD}'))
+          continue;
+        // Trim left whitespace for the first output.
+        if (this.tokenizer.trimLeft && count == 1)
+          text = text.trimLeft();
+        results[i] = text;
+        buffers[i] = [];
+      }
+      yield results;
     }
   }
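
For reference, a minimal sketch of how the batched `generate()` API introduced above might be consumed. It mirrors the flow in `src/generate.ts`; the import path, model directory, prompt, and `maxTokens` value are placeholders and assumptions, not part of the diff.

```ts
// Sketch only: generate() now yields string[] (one decoded chunk per
// sequence in the batch) instead of a single string.
import {loadLLM} from './llm.js';  // import path assumed

async function main() {
  const llm = await loadLLM('/path/to/model');      // placeholder model dir
  const promptEmbeds = await llm.encode('Hello,');  // placeholder prompt
  for await (const texts of llm.generate(promptEmbeds, {maxTokens: 64})) {
    // With a batch size of 1, texts contains a single element.
    process.stdout.write(texts[0]);
  }
  process.stdout.write('\n');
}

main();
```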