Handle batches in llm.generate API
zcbenz committed Oct 21, 2024
1 parent 609082e commit 7211d48
Showing 5 changed files with 32 additions and 26 deletions.
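In short, LLM.generate() (and the underlying step() generator) now carry a batch dimension: instead of yielding one string (or token id) at a time, each iteration yields an array with one entry per sequence in the batch. A minimal consumer sketch, assuming an llm instance already obtained via loadLLM as in src/generate.ts (imports omitted), with a batch of one:

// Hypothetical usage sketch (not part of this commit): generate() now yields
// string[] -- one decoded chunk per sequence in the batch.
const promptEmbeds = await llm.encode('Write a haiku about autumn.');
let output = '';
for await (const texts of llm.generate(promptEmbeds, { maxTokens: 64 })) {
  // With a single prompt the batch size is 1, so texts has one element.
  output += texts[0];
}
process.stdout.write(output + '\n');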
2 changes: 1 addition & 1 deletion README.md
@@ -48,7 +48,7 @@ export class LLM {
/**
* Predict next tokens using the embeddings of prompt.
*/
- async *generate(promptEmbeds: mx.array, options?: LLMGenerateOptions): AsyncGenerator<string, void, unknown>;
+ async *generate(promptEmbeds: mx.array, options?: LLMGenerateOptions): AsyncGenerator<string[], void, unknown>;
}

/**
23 changes: 12 additions & 11 deletions src/base.ts
@@ -210,7 +210,9 @@ export async function* step(promptEmbeds: mx.array,
kvCache,
topP = 0.8,
temperature = 1,
- }: StepOptions = {}): AsyncGenerator<number, void> {
+ }: StepOptions = {}): AsyncGenerator<number[], void> {
+ const [ batchSize, embeddingsSize ] = promptEmbeds.shape;
+
// Create KV Cache if none is specified in options.
const cache = kvCache ?? RotatingKVCache.create(model.getDecoderKVCacheOptions());

@@ -219,22 +221,21 @@ export async function* step(promptEmbeds: mx.array,
logits = logits.index(mx.Slice(), -1);
const [ token ] = sample(logits, topP, temperature);
await mx.asyncEval(token);
- return token.item() as number;
+ return token.tolist() as number[];
};

// Handle prompt: for encoder-decoder models we pass it to encoder, for
// decoder-only models we pass it to decoder directly.
- let nextToken: number;
+ let nextTokens: number[];
let memory: mx.array | undefined;
if (model.hasEncoder) {
- nextToken = model.decoderStartToken;
+ nextTokens = new Array(batchSize).fill(model.decoderStartToken);
memory = model.encodeEmbeddings(promptEmbeds);
mx.metal.clearCache();
} else {
// Forward prompt by steps so we don't use too much RAM.
// See also https://github.com/ml-explore/mlx-examples/pull/931
const prefillStepSize = 512;
- const embeddingsSize = promptEmbeds.shape[1];
for (let offset = 0; offset < embeddingsSize;) {
if (signal?.aborted)
break;
@@ -246,7 +247,7 @@ export async function* step(promptEmbeds: mx.array,
offset += size;
// Do token-by-token generation after prompt is consumed.
if (offset == embeddingsSize)
- nextToken = await predict(logits);
+ nextTokens = await predict(logits);
// Keep the cache from being released.
return cache;
});
@@ -256,17 +257,17 @@ export async function* step(promptEmbeds: mx.array,

do {
// Quit after getting EOS.
- if (nextToken == eosToken)
+ if (nextTokens.indexOf(eosToken) > -1)
break;
// The generation is aborted.
if (signal?.aborted)
break;
// Do not yield token if it is the decoderStartToken.
- if (!(model.hasEncoder && nextToken == model.decoderStartToken))
-   yield nextToken;
+ if (!(model.hasEncoder && nextTokens.indexOf(model.decoderStartToken) > -1))
+   yield nextTokens;
// Forward the token to model and free intermediate tensors.
- [ nextToken ] = await mx.tidy(async () => {
-   const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), memory, cache);
+ [ nextTokens ] = await mx.tidy(async () => {
+   const logits = model.forward(mx.array(nextTokens, mx.int32).index(mx.Slice(), mx.newaxis), memory, cache);
// The cache is also returned so it does not get freed by mx.tidy().
return [ await predict(logits), cache ];
});
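A shape note on the decode loop above, as a hedged illustration only (the token values are made up; the mx calls are the ones already used in this diff): the sampled tokens are now fed back to the model as a [batchSize, 1] column instead of the previous hard-coded [1, 1] array.

// Hypothetical illustration of the shapes in the new decode step.
const nextTokens = [ 42, 7, 1093 ];                  // one sampled token per sequence
const flat = mx.array(nextTokens, mx.int32);         // shape: [3]
const column = flat.index(mx.Slice(), mx.newaxis);   // shape: [3, 1]
// model.forward(column, memory, cache) then produces logits with a leading
// batch dimension, and predict() returns one sampled token per row via tolist().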
2 changes: 1 addition & 1 deletion src/chat.ts
@@ -66,7 +66,7 @@ async function talk(rl: readline.Interface,

// Predict next tokens.
let result = '';
- for await (const text of llm.generate(promptEmbeds, options)) {
+ for await (const [ text ] of llm.generate(promptEmbeds, options)) {
result += text;
process.stdout.write(text);
}
2 changes: 1 addition & 1 deletion src/generate.ts
@@ -14,7 +14,7 @@ main(argv[0], argv[1]);
async function main(dir: string, prompt?: string) {
const llm = await loadLLM(dir);
const promptEmbeds = await llm.encode(prompt);
- for await (const text of llm.generate(promptEmbeds, options))
+ for await (const [ text ] of llm.generate(promptEmbeds, options))
process.stdout.write(text);
process.stdout.write('\n');
}
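The CLI still encodes a single prompt and destructures the first element. For callers that want several sequences at once, a hedged sketch of what consuming a real batch could look like, assuming the prompts encode to embeddings of equal length and stacking them along the batch axis with mx.concatenate; batch construction itself is not part of this commit:

// Hypothetical batched usage sketch.
const a = await llm.encode('First prompt');
const b = await llm.encode('Second prompt');   // assumed same encoded length as `a`
const batched = mx.concatenate([ a, b ], 0);   // stack along the batch axis
const outputs = [ '', '' ];
for await (const texts of llm.generate(batched, options))
  texts.forEach((text, i) => outputs[i] += text);
console.log(outputs);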
29 changes: 17 additions & 12 deletions src/llm.ts
@@ -99,6 +99,7 @@ export class LLM {
* Predict next tokens using the embeddings of prompt.
*/
async *generate(promptEmbeds: mx.array, options: LLMGenerateOptions = {}) {
+ const [ batchSize ] = promptEmbeds.shape;
// If not specified, create a shared cache between generations.
if (!options.kvCache) {
if (!this.kvCache) {
@@ -111,22 +112,26 @@ export class LLM {
options.kvCache = this.kvCache;
}
// Predict next tokens.
- let buffer: number[] = [];
+ let buffers: number[][] = Array.from({length: batchSize}, () => []);
let count = 0;
- for await (const token of step(promptEmbeds, this.model, this.tokenizer.eosToken, options)) {
+ for await (const tokens of step(promptEmbeds, this.model, this.tokenizer.eosToken, options)) {
++count;
if (options.maxTokens && count > options.maxTokens)
break;
- buffer.push(token);
- let text = this.tokenizer.decode(buffer);
- // The token may represent an incomplete unicode char.
- if (text.endsWith('\u{FFFD}'))
-   continue;
- // Trim left whitespace for the first output.
- if (this.tokenizer.trimLeft && count == 1)
-   text = text.trimLeft();
- yield text;
- buffer = [];
+ const results: string[] = Array.from({length: batchSize}, () => '');
+ for (let i = 0; i < batchSize; ++i) {
+   buffers[i].push(tokens[i]);
+   let text = this.tokenizer.decode(buffers[i]);
+   // The token may represent an incomplete unicode char.
+   if (text.endsWith('\u{FFFD}'))
+     continue;
+   // Trim left whitespace for the first output.
+   if (this.tokenizer.trimLeft && count == 1)
+     text = text.trimLeft();
+   results[i] = text;
+   buffers[i] = [];
+ }
+ yield results;
}
}

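One detail worth spelling out in the per-sequence decoding above: a sampled token can stop in the middle of a multi-byte character, in which case decoding the buffer ends with the U+FFFD replacement character, so the tokens stay buffered (and that sequence yields an empty string for the step) until they decode cleanly. A stand-alone sketch of that check, with a hypothetical decodeTokens standing in for this.tokenizer.decode:

// Hypothetical helper mirroring the incomplete-character check used above.
function flushIfComplete(buffer: number[], decodeTokens: (tokens: number[]) => string): string {
  const text = decodeTokens(buffer);
  // An incomplete UTF-8 sequence decodes to the replacement character U+FFFD;
  // keep buffering until the text ends cleanly.
  if (text.endsWith('\u{FFFD}'))
    return '';
  buffer.length = 0;   // reset the per-sequence buffer
  return text;
}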
