Skip to content

Commit

Permalink
Decouple BaseModel and BaseKVCache
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Oct 7, 2024
1 parent 174fc99 commit 1662f9a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 52 deletions.
14 changes: 6 additions & 8 deletions src/base.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx';
import {BaseKVCache, RotatingKVCache} from './kv-cache.js';
import {KVCacheOptions, BaseKVCache, RotatingKVCache} from './kv-cache.js';
import {loadWeights, readJsonSync} from './fs.js';

/**
Expand Down Expand Up @@ -43,7 +43,7 @@ export abstract class BaseModel extends nn.Module {
/**
* Pass the embeddings to the encoder and return the memory.
*/
encodeTextEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
encodeEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
throw new Error('This model has no encoder.');
}

Expand Down Expand Up @@ -98,10 +98,8 @@ export abstract class BaseModel extends nn.Module {
return mx.concatenate(segments, 1);
}

// Following properties are defined for internal KV cache use only.
abstract get layers(): nn.Module[];
abstract get headDim(): number;
abstract get nKVHeads(): number;
// The following method is defined for internal KV cache use only.
abstract getDecoderKVCacheOptions(): KVCacheOptions;
}

/**
Expand Down Expand Up @@ -214,7 +212,7 @@ export async function* step(promptEmbeds: mx.array,
temperature = 1,
}: StepOptions = {}): AsyncGenerator<number, void> {
// Create KV Cache if none is specified in options.
const cache = kvCache ?? RotatingKVCache.createForModel(model);
const cache = kvCache ?? RotatingKVCache.create(model.getDecoderKVCacheOptions());

// Sample the logits results.
const predict = async (logits: mx.array) => {
Expand All @@ -230,7 +228,7 @@ export async function* step(promptEmbeds: mx.array,
let memory: mx.array | undefined;
if (model.hasEncoder) {
nextToken = model.decoderStartToken;
memory = model.encodeTextEmbeddings(promptEmbeds);
memory = model.encodeEmbeddings(promptEmbeds);
} else {
// Forward prompt by steps so we don't use too much RAM.
// See also https://github.com/ml-explore/mlx-examples/pull/931
Expand Down
23 changes: 14 additions & 9 deletions src/kv-cache.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import {core as mx, nn} from '@frost-beta/mlx';
import type {BaseModel} from './base.js';

/**
 * Plain description of the KV caches a model needs, so cache classes can be
 * created without importing the model class itself.
 */
export interface KVCacheOptions {
// How many cache instances to create (used as the loop bound in create()).
nLayers: number;
// Forwarded as the first constructor argument of each cache instance.
headDim: number;
// Forwarded as the second constructor argument of each cache instance.
nKVHeads: number;
}

/**
* The base class of KV cache.
Expand All @@ -10,12 +15,12 @@ export abstract class BaseKVCache {
offset = 0;
step = 256;

static createForModel<T extends BaseKVCache>(
model: BaseModel,
static create<T extends BaseKVCache>(
options: KVCacheOptions,
construct: new (headDim: number, nKVHeads: number) => T) {
const cache: BaseKVCache[] = [];
for (let i = 0; i < model.layers.length; ++i)
cache[i] = new construct(model.headDim, model.nKVHeads);
for (let i = 0; i < options.nLayers; ++i)
cache[i] = new construct(options.headDim, options.nKVHeads);
return cache;
}

Expand All @@ -38,8 +43,8 @@ export class KVCache extends BaseKVCache {
super();
}

static override createForModel(model: BaseModel) {
return BaseKVCache.createForModel<KVCache>(model, KVCache);
/**
 * Build an array of KVCache instances, one per layer described by `options`,
 * by delegating to the generic factory on BaseKVCache.
 */
static override create(options: KVCacheOptions) {
  const caches = BaseKVCache.create<KVCache>(options, KVCache);
  return caches;
}

override updateAndFetch(keys: mx.array, values: mx.array): [ mx.array, mx.array ] {
Expand Down Expand Up @@ -87,8 +92,8 @@ export class RotatingKVCache extends BaseKVCache {
vHeadDim: number;
#idx = 0;

static override createForModel(model: BaseModel) {
return BaseKVCache.createForModel(model, RotatingKVCache);
/**
 * Build an array of RotatingKVCache instances, one per layer described by
 * `options`, by delegating to the generic factory on BaseKVCache.
 */
static override create(options: KVCacheOptions) {
  const caches = BaseKVCache.create(options, RotatingKVCache);
  return caches;
}

constructor(headDim: number,
Expand Down
2 changes: 1 addition & 1 deletion src/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export class LLM {
// If not specified, create a shared cache between generations.
if (!options.kvCache) {
if (!this.kvCache)
this.kvCache = KVCache.createForModel(this.model);
this.kvCache = KVCache.create(this.model.getDecoderKVCacheOptions());
options.kvCache = this.kvCache;
}
// Predict next tokens.
Expand Down
18 changes: 7 additions & 11 deletions src/models/llama.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {core as mx, nn} from '@frost-beta/mlx';
import {BaseModel, baseModelArgs, createAttentionMask} from '../base.js';
import {BaseKVCache} from '../kv-cache.js';
import {KVCacheOptions, BaseKVCache} from '../kv-cache.js';

export interface RopeScaling {
type?: string;
Expand Down Expand Up @@ -283,15 +283,11 @@ export class Model extends BaseModel {
return this.lmHead.forward(out);
}

override get layers() {
return this.model.layers;
}

override get headDim() {
return Math.floor(this.args.hiddenSize / this.args.numAttentionHeads);
}

override get nKVHeads() {
return this.args.numKeyValueHeads;
/**
 * Describe the KV caches required by this model's decoder layers.
 *
 * @returns The layer count, per-head dimension and number of KV heads.
 */
override getDecoderKVCacheOptions(): KVCacheOptions {
  return {
    nLayers: this.model.layers.length,
    // Floor the quotient so headDim stays an integer even when hiddenSize is
    // not an exact multiple of numAttentionHeads.
    headDim: Math.floor(this.args.hiddenSize / this.args.numAttentionHeads),
    nKVHeads: this.args.numKeyValueHeads,
  };
}
}
14 changes: 3 additions & 11 deletions src/models/llava.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {core as mx, nn} from '@frost-beta/mlx';
import {BaseModel, baseModelArgs} from '../base.js';
import {BaseKVCache} from '../kv-cache.js';
import {KVCacheOptions, BaseKVCache} from '../kv-cache.js';
import {VisionConfig, VisionModel} from './llava/vision.js';
import * as llama from './llama.js';

Expand Down Expand Up @@ -121,15 +121,7 @@ export class Model extends BaseModel {
}
}

override get layers() {
return this.languageModel.layers as nn.Module[];
}

override get headDim() {
return this.languageModel.headDim;
}

override get nKVHeads() {
return this.languageModel.nKVHeads;
/**
 * Report the KV cache options of the wrapped language model.
 */
override getDecoderKVCacheOptions(): KVCacheOptions {
  const options = this.languageModel.getDecoderKVCacheOptions();
  return options;
}
}
20 changes: 8 additions & 12 deletions src/models/t5.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {core as mx, nn} from '@frost-beta/mlx';
import {BaseModel, baseModelArgs, createAttentionMask} from '../base.js';
import {BaseKVCache} from '../kv-cache.js';
import {KVCacheOptions, BaseKVCache} from '../kv-cache.js';

export interface ModelArgs {
classifierDropout: number;
Expand Down Expand Up @@ -393,19 +393,15 @@ export class Model extends BaseModel {
return y;
}

override encodeTextEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
override encodeEmbeddings(embeddings: mx.array): mx.array {
return this.encoder.forward(embeddings);
}

override get layers() {
return this.decoder.block;
}

override get headDim() {
return this.args.dKv;
}

override get nKVHeads() {
return this.args.numHeads;
/**
 * Describe the KV caches required by the decoder: one per decoder block,
 * with head size `dKv` and `numHeads` KV heads taken from the model args.
 */
override getDecoderKVCacheOptions(): KVCacheOptions {
  const nLayers = this.decoder.block.length;
  const headDim = this.args.dKv;
  const nKVHeads = this.args.numHeads;
  return {nLayers, headDim, nKVHeads};
}
}

0 comments on commit 1662f9a

Please sign in to comment.