Skip to content

Commit

Permalink
Update the implementation of RotatingKVCache
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Oct 8, 2024
1 parent 1662f9a commit 5202fc6
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 89 deletions.
30 changes: 15 additions & 15 deletions src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,16 @@ export function baseModelArgs<T>(args: T): T {
}

/**
* Create an additive causal mask.
* Create causal mask.
*/
export function createAdditiveCausalMask(N: number, offset = 0) {
const rinds = mx.arange(offset + N);
const linds = offset ? mx.arange(offset, offset + N) : rinds;
const mask = mx.less(linds.index(mx.Slice(), mx.newaxis),
rinds.index(mx.newaxis));
export function createCausalMask(N: number, offset = 0, windowSize?: number) {
let rinds = mx.arange(offset + N);
let linds = offset ? mx.arange(offset, offset + N) : rinds;
rinds = rinds.index(mx.newaxis);
linds = linds.index(mx.Slice(), mx.newaxis);
let mask = mx.less(linds, rinds);
if (windowSize !== undefined)
mask = mx.bitwiseOr(mask, mx.greater(linds, mx.add(rinds, windowSize)));
return mx.multiply(mask, -1e9);
}

Expand All @@ -135,17 +138,14 @@ export function createAdditiveCausalMask(N: number, offset = 0) {
export function createAttentionMask(h: mx.array, cache?: BaseKVCache[]) {
const T = h.shape[1];
if (T > 1) {
let offset: number;
let windowSize: number | undefined;
let offset = 0;
if (cache) {
const c = cache[0];
if (c instanceof RotatingKVCache)
offset = Math.min(c.maxSize - 1, c.offset);
else
offset = c.offset;
} else {
offset = 0;
offset = cache[0].offset;
if (cache[0] instanceof RotatingKVCache)
windowSize = cache[0].maxSize;
}
return createAdditiveCausalMask(T, offset).astype(h.dtype);
return createCausalMask(T, offset, windowSize).astype(h.dtype);
} else {
return null;
}
Expand Down
161 changes: 100 additions & 61 deletions src/kv-cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ import {core as mx, nn} from '@frost-beta/mlx';

export interface KVCacheOptions {
nLayers: number;
headDim: number;
nKVHeads: number;
}

/**
Expand All @@ -12,23 +10,20 @@ export interface KVCacheOptions {
export abstract class BaseKVCache {
keys?: mx.array;
values?: mx.array;
offset = 0;
step = 256;

static create<T extends BaseKVCache>(
options: KVCacheOptions,
construct: new (headDim: number, nKVHeads: number) => T) {
static create<T extends BaseKVCache>(options: KVCacheOptions,
construct: new (...args: any[]) => T,
...args: any[]) {
const cache: BaseKVCache[] = [];
for (let i = 0; i < options.nLayers; ++i)
cache[i] = new construct(options.headDim, options.nKVHeads);
cache[i] = new construct(...args);
return cache;
}

abstract updateAndFetch(keys: mx.array, values: mx.array): [ mx.array, mx.array ];

get state() {
return [ this.keys, this.values ];
}
abstract get state(): any;
abstract get offset(): number;
}

/**
Expand All @@ -38,23 +33,22 @@ export abstract class BaseKVCache {
* See also https://github.com/ml-explore/mlx-examples/issues/724.
*/
export class KVCache extends BaseKVCache {
constructor(public headDim: number,
public nKVHeads: number) {
super();
}
#offset = 0;

static override create(options: KVCacheOptions) {
return BaseKVCache.create<KVCache>(options, KVCache);
}

override updateAndFetch(keys: mx.array, values: mx.array): [ mx.array, mx.array ] {
const prev = this.offset;
const prev = this.#offset;
if (!this.keys || (prev + keys.shape[2] > this.keys.shape[2])) {
const B = keys.shape[0];
const [ B, nKVHeads, , kHeadDim ] = keys.shape;
const vHeadDim = values.shape[3];
const nSteps = Math.floor((this.step + keys.shape[2] - 1) / this.step);
const shape = [ B, this.nKVHeads, nSteps * this.step, this.headDim ];
const newK = mx.zeros(shape, keys.dtype);
const newV = mx.zeros(shape, values.dtype);
const kShape = [ B, nKVHeads, nSteps * this.step, kHeadDim ];
const vShape = [ B, nKVHeads, nSteps * this.step, vHeadDim ];
const newK = mx.zeros(kShape, keys.dtype);
const newV = mx.zeros(vShape, values.dtype);
if (this.keys) {
const old = [ this.keys, this.values ];
if (prev % this.step != 0) {
Expand All @@ -71,15 +65,28 @@ export class KVCache extends BaseKVCache {
}
}

this.offset += keys.shape[2];
this.#offset += keys.shape[2];

const insert: mx.ArrayIndex[] = [ '...', mx.Slice(prev, this.offset), mx.Slice() ];
const insert: mx.ArrayIndex[] = [ '...', mx.Slice(prev, this.#offset), mx.Slice() ];
this.keys.indexPut_(insert, keys);
this.values.indexPut_(insert, values);

const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.offset), mx.Slice() ];
const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.#offset), mx.Slice() ];
return [ this.keys.index(...get), this.values.index(...get) ];
}

override get state() {
if (this.#offset == this.keys.shape[2]) {
return [ this.keys, this.values ];
} else {
const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.#offset), mx.Slice() ];
return [ this.keys.index(...get), this.values.index(...get) ];
}
}

override get offset() {
return this.#offset;
}
}

/**
Expand All @@ -88,52 +95,69 @@ export class KVCache extends BaseKVCache {
* See also https://github.com/ml-explore/mlx-examples/pull/931.
*/
export class RotatingKVCache extends BaseKVCache {
kHeadDim: number;
vHeadDim: number;
#offset = 0;
#idx = 0;

static override create(options: KVCacheOptions) {
return BaseKVCache.create(options, RotatingKVCache);
static override create(options: KVCacheOptions, ...args: any[]) {
console.log(args);
return BaseKVCache.create(options, RotatingKVCache, ...args);
}

constructor(headDim: number,
public nKVHeads: number,
public maxSize = 1024,
public keep = 4) {
constructor(public maxSize = 1024, public keep = 4) {
super();
this.kHeadDim = this.vHeadDim = headDim;
}

override updateAndFetch(keys: mx.array, values: mx.array): [ mx.array, mx.array ] {
const prev = this.offset;
const [ B, , S ] = keys.shape;

// Prefill mode.
if (S > 1) {
if (!this.keys) {
this.keys = keys;
this.values = values;
} else {
// The largest size is this.maxSize + S - 1 to ensure every token gets
// at least this.maxSize context.
const trimSize = this.keys.shape[2] - this.maxSize + 1;
const old = [ this.keys, this.values ];
this.keys = this.trim(trimSize, this.keys, keys);
this.values = this.trim(trimSize, this.values, values);
mx.dispose(old);
}
this.offset += S;
this.#idx = this.keys.shape[2];
if (keys.shape[2] == 1)
return this.updateInPlace(keys, values);
else
return this.updateConcat(keys, values);
}

override get state() {
if (this.#offset < this.keys.shape[2]) {
const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.#offset), mx.Slice() ];
return [ this.keys.index(...get), this.values.index(...get) ];
} else {
return [ this.keys, this.values ];
}
}

override get offset(): number {
return Math.min(this.maxSize - 1, this.#offset);
}

// Generation mode.
private updateConcat(keys: mx.array, values: mx.array): [ mx.array, mx.array ] {
if (!this.keys) {
this.keys = keys;
this.values = values;
} else {
// Put the keys/values in temporal order to preserve context.
const old = [ this.keys, this.values ];
this.keys = this.temporalOrder(this.keys);
this.values = this.temporalOrder(this.values);

// The largest size is self.max_size + S - 1 to ensure every token gets
// at least this.maxSize context.
const trimSize = this.#idx - this.maxSize + 1;
this.keys = this.trim(trimSize, this.keys, keys);
this.values = this.trim(trimSize, this.values, values);
mx.dispose(old);
}
this.#offset += keys.shape[2];
this.#idx = this.keys.shape[2];
return [ this.keys, this.values ];
}

// May not have hit the max size yet, so potentiall keep growing the cache.
private updateInPlace(keys: mx.array, values: mx.array): [ mx.array, mx.array ] {
// May not have hit the max size yet, so potentially keep growing the cache.
const [ B, nKVHeads, S, kHeadDim ] = keys.shape;
const prev = this.#offset;
if (!this.keys || (prev >= this.keys.shape[2] && this.keys.shape[2] < this.maxSize)) {
const vHeadDim = values.shape[3];
const newSize = Math.min(this.step, this.maxSize - prev);
const kShape = [ B, this.nKVHeads, newSize, this.kHeadDim ];
const vShape = [ B, this.nKVHeads, newSize, this.vHeadDim ];
const kShape = [ B, nKVHeads, newSize, kHeadDim ];
const vShape = [ B, nKVHeads, newSize, vHeadDim ];
const newK = mx.zeros(kShape, keys.dtype);
const newV = mx.zeros(vShape, values.dtype);
if (this.keys) {
Expand Down Expand Up @@ -164,15 +188,15 @@ export class RotatingKVCache extends BaseKVCache {
}

// Assign.
const insert: mx.ArrayIndex[] = [ '...', mx.Slice(this.#idx, this.#idx + 1), mx.Slice() ];
const insert: mx.ArrayIndex[] = [ '...', mx.Slice(this.#idx, this.#idx + S), mx.Slice() ];
this.keys.indexPut_(insert, keys);
this.values.indexPut_(insert, values);
this.offset += 1;
this.#idx += 1;
this.#offset += S;
this.#idx += S;

// If the buffer is not full, slice off the end.
if (this.offset < this.maxSize) {
const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.offset), mx.Slice() ];
if (this.#offset < this.maxSize) {
const get: mx.ArrayIndex[] = [ '...', mx.Slice(null, this.#offset), mx.Slice() ];
return [ this.keys.index(...get), this.values.index(...get) ];
}
return [ this.keys, this.values ];
Expand All @@ -181,7 +205,7 @@ export class RotatingKVCache extends BaseKVCache {
private trim(trimSize: number, v: mx.array, append?: mx.array) {
let toCat: mx.array[];
if (trimSize > 0) {
toCat = [ v.index('...', mx.Slice(0, this.keep), mx.Slice()),
toCat = [ v.index('...', mx.Slice(null, this.keep), mx.Slice()),
v.index('...', mx.Slice(trimSize + this.keep), mx.Slice()) ];
} else {
toCat = [ v ];
Expand All @@ -191,4 +215,19 @@ export class RotatingKVCache extends BaseKVCache {
}
return mx.concatenate(toCat, 2);
}

// Rearrange the cache into temporal order, slicing off the end if unused.
private temporalOrder(v: mx.array) {
if (this.#idx == v.shape[2]) {
return v;
} else if (this.#idx < this.#offset) {
return mx.concatenate([
v.index('...', mx.Slice(null, this.keep), mx.Slice()),
v.index('...', mx.Slice(this.#idx), mx.Slice()),
v.index('...', mx.Slice(this.keep, this.#idx), mx.Slice()),
], 2);
} else {
return v.index('...', mx.Slice(null, this.#idx), mx.Slice());
}
}
}
16 changes: 13 additions & 3 deletions src/llm.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {fileURLToPath} from 'node:url';
import {core as mx} from '@frost-beta/mlx';
import {BaseModel, StepOptions, loadModel, step} from './base.js';
import {BaseKVCache, KVCache} from './kv-cache.js';
import {BaseKVCache, KVCache, RotatingKVCache} from './kv-cache.js';
import {ChatTemplateOptions, Message, Tokenizer} from './tokenizer.js';
import {ImageInputType, ImageProcessor} from './image-processor.js';

Expand All @@ -10,6 +10,7 @@ import {ImageInputType, ImageProcessor} from './image-processor.js';
*/
export interface LLMGenerateOptions extends StepOptions {
maxTokens?: number;
maxKVSize?: number;
}

/**
Expand All @@ -22,6 +23,10 @@ export function parseArgs(args: string[]): [ string[], LLMGenerateOptions ] {
options.maxTokens = parseInt(arg.substring(arg.indexOf('=') + 1));
return false;
}
if (arg.startsWith('--max-kv-size=')) {
options.maxKVSize = parseInt(arg.substring(arg.indexOf('=') + 1));
return false;
}
if (arg.startsWith('--temperature=')) {
options.temperature = parseFloat(arg.substring(arg.indexOf('=') + 1));
return false;
Expand Down Expand Up @@ -98,8 +103,13 @@ export class LLM {
this.model.eval();
// If not specified, create a shared cache between generations.
if (!options.kvCache) {
if (!this.kvCache)
this.kvCache = KVCache.create(this.model.getDecoderKVCacheOptions());
if (!this.kvCache) {
const kvCacheOptions = this.model.getDecoderKVCacheOptions();
if (options.maxKVSize)
this.kvCache = RotatingKVCache.create(kvCacheOptions, options.maxKVSize);
else
this.kvCache = KVCache.create(kvCacheOptions);
}
options.kvCache = this.kvCache;
}
// Predict next tokens.
Expand Down
6 changes: 1 addition & 5 deletions src/models/llama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,6 @@ export class Model extends BaseModel {
}

override getDecoderKVCacheOptions(): KVCacheOptions {
return {
nLayers: this.model.layers.length,
headDim: this.args.hiddenSize / this.args.numAttentionHeads,
nKVHeads: this.args.numKeyValueHeads,
};
return {nLayers: this.model.layers.length};
}
}
6 changes: 1 addition & 5 deletions src/models/t5.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,6 @@ export class Model extends BaseModel {
}

override getDecoderKVCacheOptions(): KVCacheOptions {
return {
nLayers: this.decoder.block.length,
headDim: this.args.dKv,
nKVHeads: this.args.numHeads,
};
return {nLayers: this.decoder.block.length};
}
}

0 comments on commit 5202fc6

Please sign in to comment.