Make BaseModel aware of encoders
zcbenz committed Oct 3, 2024
1 parent 1344ddc commit 6268207
Showing 5 changed files with 113 additions and 67 deletions.
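The heart of the change is in src/base.ts below: BaseModel gains a hasEncoder flag and a decoderStartToken, the old forwardEmbeddings becomes decodeEmbeddings with an optional encoder memory, and encodeTextEmbeddings gets a throwing default so decoder-only models need no extra work. A sketch of what a minimal decoder-only subclass looks like under the new contract (import paths and layer sizes are assumptions; the other abstract members, such as the cache-related getters, are elided):

```ts
import {core as mx, nn} from '@frost-beta/mlx';  // assumed import path
import {BaseModel, BaseKVCache} from './base.js';

// Sketch: a decoder-only model only implements decodeEmbeddings; hasEncoder
// stays false and the inherited encodeTextEmbeddings throws if called.
class TinyModel extends BaseModel {
  embed = new nn.Embedding(32000, 512);  // hypothetical sizes
  head = new nn.Linear(512, 32000);      // stand-in for a real transformer stack

  override computeTextEmbeddings(inputs: mx.array): mx.array {
    return this.embed.forward(inputs);
  }

  override decodeEmbeddings(embeddings: mx.array, memory?: mx.array,
                            cache?: BaseKVCache[]): mx.array {
    if (memory)
      throw new Error('This model has no encoder.');
    return this.head.forward(embeddings);  // real models run attention here
  }
}
```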
71 changes: 47 additions & 24 deletions src/base.ts
@@ -6,6 +6,14 @@ import {loadWeights, readJsonSync} from './fs.js';
* The base class of LLM models with or without vision.
*/
export abstract class BaseModel extends nn.Module {
/**
* Whether this model has an encoder.
*/
hasEncoder = false;
/**
* The token used as the first token for the decoder.
*/
decoderStartToken?: number;
/**
* The special token representing image, usually <image> or <|image|>.
*/
@@ -30,13 +38,20 @@ export abstract class BaseModel extends nn.Module {
/**
* Predict next token for the embeddings.
*/
abstract forwardEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array;
abstract decodeEmbeddings(embeddings: mx.array, memory?: mx.array, cache?: BaseKVCache[]): mx.array;

/**
* Pass the text embeddings to encoder and return memory.
*/
encodeTextEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
throw new Error('This model has no encoder.');
}

/**
* Predict next token for the encoded text.
*/
override forward(inputs: mx.array, cache?: BaseKVCache[]): mx.array {
return this.forwardEmbeddings(this.computeTextEmbeddings(inputs), cache);
override forward(inputs: mx.array, memory?: mx.array, cache?: BaseKVCache[]): mx.array {
return this.decodeEmbeddings(this.computeTextEmbeddings(inputs), memory, cache);
}

/**
@@ -209,26 +224,34 @@ export async function* step(promptEmbeds: mx.array,
return token.item() as number;
};

// Forward prompt by steps so we don't use too much RAM.
// See also https://github.com/ml-explore/mlx-examples/pull/931
let nextToken = eosToken;
const prefillStepSize = 512;
const embeddingsSize = promptEmbeds.shape[1];
for (let offset = 0; offset < embeddingsSize;) {
if (signal?.aborted)
break;
await mx.tidy(async () => {
const size = Math.min(prefillStepSize, embeddingsSize - offset);
const chunk = promptEmbeds.index(mx.Slice(), mx.Slice(offset, offset + size));
const logits = model.forwardEmbeddings(chunk, cache);
mx.eval(cache.map(c => c.state));
offset += size;
// Do token-by-token generation after prompt is consumed.
if (offset == embeddingsSize)
nextToken = await predict(logits);
// Keep the cache from being released.
return cache;
});
// Handle the prompt: for encoder-decoder models we pass it to the encoder, for
// decoder-only models we pass it to the decoder directly.
let nextToken: number;
let memory: mx.array | undefined;
if (model.hasEncoder) {
nextToken = model.decoderStartToken;
memory = model.encodeTextEmbeddings(promptEmbeds);
} else {
// Forward prompt by steps so we don't use too much RAM.
// See also https://github.com/ml-explore/mlx-examples/pull/931
const prefillStepSize = 512;
const embeddingsSize = promptEmbeds.shape[1];
for (let offset = 0; offset < embeddingsSize;) {
if (signal?.aborted)
break;
await mx.tidy(async () => {
const size = Math.min(prefillStepSize, embeddingsSize - offset);
const chunk = promptEmbeds.index(mx.Slice(), mx.Slice(offset, offset + size));
const logits = model.decodeEmbeddings(chunk, undefined, cache);
mx.eval(cache.map(c => c.state));
offset += size;
// Do token-by-token generation after prompt is consumed.
if (offset == embeddingsSize)
nextToken = await predict(logits);
// Keep the cache from being released.
return cache;
});
}
}

do {
@@ -241,7 +264,7 @@
yield nextToken;
// Forward the token to the model and free intermediate tensors.
[ nextToken ] = await mx.tidy(async () => {
const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), cache);
const logits = model.forward(mx.array([ [ nextToken ] ], mx.int32), memory, cache);
// The cache is also returned so it does not get freed by mx.tidy().
return [ await predict(logits), cache ];
});
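Seen end to end, the encoder branch turns step into the classic seq2seq loop: run the encoder once over the prompt embeddings, then feed tokens through the decoder one at a time with the same memory driving cross-attention. A condensed sketch of that flow (greedy sampling for brevity; forward's signature is the one from the diff, import paths are assumed):

```ts
import {core as mx} from '@frost-beta/mlx';
import {BaseModel, BaseKVCache} from './base.js';

// Sketch: encoder-decoder generation after this change.
async function* seq2seqSteps(model: BaseModel, promptEmbeds: mx.array,
                             cache: BaseKVCache[], eosToken: number) {
  const memory = model.encodeTextEmbeddings(promptEmbeds);  // encoder runs once
  let token = model.decoderStartToken!;                     // e.g. the pad token for T5
  while (true) {
    // Decode one token; the same memory feeds cross-attention every step.
    const logits = model.forward(mx.array([ [ token ] ], mx.int32), memory, cache);
    token = mx.argmax(logits.index(mx.Slice(), -1), -1).item() as number;  // greedy
    if (token == eosToken)
      break;
    yield token;
  }
}
```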
3 changes: 2 additions & 1 deletion src/llm.ts
@@ -58,7 +58,7 @@ export class LLM {
tokens = this.tokenizer.encode(text);
// Some tokenizers append EOS to the encoded text; remove it, otherwise the
// generation might stop there.
if (tokens.length > 1 && tokens.at(-1) == this.tokenizer.eosToken)
if (!this.model.hasEncoder && tokens.length > 1 && tokens.at(-1) == this.tokenizer.eosToken)
tokens.pop();
} else {
tokens = [ this.tokenizer.bosToken ];
@@ -95,6 +95,7 @@ export class LLM {
* Predict next tokens using the embeddings of prompt.
*/
async *generate(promptEmbeds: mx.array, options: LLMGenerateOptions = {}) {
this.model.eval();
// If not specified, create a shared cache between generations.
if (!options.kvCache) {
if (!this.kvCache)
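The guard above exists because the two model families want opposite things from a trailing EOS: a decoder-only model would stop generating immediately if the prompt already ends with EOS, while a T5-style encoder expects its input to be terminated with </s>. The new this.model.eval() call matters too, since the T5 model below instantiates nn.Dropout layers that eval mode turns off during generation. A minimal illustration of the EOS rule (token ids are made up):

```ts
const eosToken = 1;                      // hypothetical id
let tokens = [ 42, 17, 99, eosToken ];   // tokenizer appended EOS to the prompt

if (!model.hasEncoder && tokens.length > 1 && tokens.at(-1) == eosToken)
  tokens.pop();  // decoder-only: drop EOS so generation continues past the prompt
// Encoder-decoder: keep it; the encoder input should end with </s>.
```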
7 changes: 4 additions & 3 deletions src/models/llama.ts
@@ -261,9 +261,8 @@ export class Model extends BaseModel {
lmHead: nn.Linear;

constructor(json: any) {
const args = modelArgs(json);
super();

const args = modelArgs(json);
this.args = args;
this.model = new LlamaModel(args);
if (!args.tieWordEmbeddings)
@@ -274,7 +273,9 @@
return this.model.embedTokens.forward(inputs);
}

override forwardEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
override decodeEmbeddings(embeddings: mx.array, memory?: mx.array, cache?: BaseKVCache[]): mx.array {
if (memory)
throw new Error('This model has no encoder.');
const out = this.model.forward(embeddings, cache);
if (this.args.tieWordEmbeddings)
return this.model.embedTokens.asLinear(out);
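A side note on the tied-embeddings branch visible here: when args.tieWordEmbeddings is set there is no separate lmHead, and embedTokens.asLinear(out) reuses the input embedding matrix as the output projection, i.e. logits = hidden @ E^T. A sketch of the equivalent plain-ops version (assuming an embedding weight of shape [vocabSize, dModel]):

```ts
// Tied output head: project hidden states through the transposed embeddings.
// hidden: [batch, seq, dModel], embeddingWeight: [vocabSize, dModel]
function tiedLogits(hidden: mx.array, embeddingWeight: mx.array): mx.array {
  return mx.matmul(hidden, embeddingWeight.T);  // -> [batch, seq, vocabSize]
}
```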
14 changes: 8 additions & 6 deletions src/models/llava.ts
@@ -67,16 +67,18 @@ export class Model extends BaseModel {
visionFeatureLayer: number;
visionFeatureSelectStrategy: string;

constructor(args: ModelArgs) {
constructor(json: any) {
super();
args = modelArgs(args);
this.imagePlaceholder = '<image>';
this.imageToken = args.imageTokenIndex;
const args = modelArgs(json);
this.visionTower = new VisionModel(args.visionConfig);
this.languageModel = new llama.Model(args.textConfig);
this.multiModalProjector = new LlavaMultiModalProjector(args);
this.visionFeatureLayer = args.visionFeatureLayer;
this.visionFeatureSelectStrategy = args.visionFeatureSelectStrategy;

this.hasEncoder = this.languageModel.hasEncoder;
this.imagePlaceholder = '<image>';
this.imageToken = args.imageTokenIndex;
}

override computePixelEmbeddings(pixels: mx.array): mx.array {
@@ -100,8 +102,8 @@
return this.languageModel.computeTextEmbeddings(inputs);
}

override forwardEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
return this.languageModel.forwardEmbeddings(embeddings, cache);
override decodeEmbeddings(embeddings: mx.array, memory?: mx.array, cache?: BaseKVCache[]): mx.array {
return this.languageModel.decodeEmbeddings(embeddings, memory, cache);
}

override sanitize(weights: Record<string, mx.array>) {
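The constructor change here follows the same pattern as llama.ts: every Model now accepts the raw config.json and normalizes it internally with modelArgs, which keeps model loading architecture-agnostic. A hypothetical factory to show the shape of that contract (readJsonSync is from the fs helpers imported in base.ts; the switch and the model_type values are illustrative):

```ts
import {readJsonSync} from './fs.js';

// Hypothetical loader: every model class takes raw config JSON after this change.
function createModel(dir: string): BaseModel {
  const json = readJsonSync(`${dir}/config.json`);
  switch (json.model_type) {
    case 'llama': return new llama.Model(json);
    case 'llava': return new llava.Model(json);
    case 't5':    return new t5.Model(json);
    default: throw new Error(`Unsupported model type: ${json.model_type}`);
  }
}
```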
85 changes: 52 additions & 33 deletions src/models/t5.ts
@@ -7,21 +7,21 @@ export interface ModelArgs {
dFf: number;
dKv: number;
dModel: number;
decoderStartTokenId: number;
denseActFn: string;
dropoutRate: number;
eosTokenId: number;
feedForwardProj: string;
initializerFactor: number;
isEncoderDecoder: boolean;
isGatedAct: boolean;
layerNormEpsilon: number;
numDecoderLayers?: number;
numDecoderLayers: number;
numHeads: number;
numLayers: number;
padTokenId: number;
relativeAttentionMaxDistance: number;
relativeAttentionNumBuckets: number;
useCache: boolean;
tieWordEmbeddings: boolean;
vocabSize: number;
}

@@ -35,16 +35,23 @@ export function modelArgs(json: any): ModelArgs {
eosTokenId: 1,
feedForwardProj: 'relu',
initializerFactor: 1.0,
isEncoderDecoder: true,
layerNormEpsilon: 1e-6,
numHeads: 8,
numLayers: 6,
padTokenId: 0,
relativeAttentionMaxDistance: 128,
relativeAttentionNumBuckets: 32,
useCache: true,
tieWordEmbeddings: true,
vocabSize: 32128,
}, baseModelArgs(json));
if (args.decoderStartTokenId === undefined) {
args.decoderStartTokenId = args.padTokenId;
if (args.decoderStartTokenId === undefined)
throw new Error('Must provide "decoder_start_token_id" or "pad_token_id"');
}
if (!args.numDecoderLayers) {
args.numDecoderLayers = args.numLayers;
}
args.denseActFn = args.feedForwardProj.split('-').at(-1);
args.isGatedAct = args.feedForwardProj.startsWith('gated-');
return args;
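With the defaults above, a sparse HF-style T5 config resolves predictably; a worked example (assuming baseModelArgs handles the snake_case to camelCase conversion):

```ts
// Given only {"feed_forward_proj": "gated-gelu"} in config.json:
const args = modelArgs({ feed_forward_proj: 'gated-gelu' });
args.padTokenId;           // 0       (default)
args.decoderStartTokenId;  // 0       (falls back to padTokenId)
args.numDecoderLayers;     // 6       (falls back to numLayers)
args.denseActFn;           // 'gelu'  (suffix of feedForwardProj)
args.isGatedAct;           // true    (feedForwardProj starts with 'gated-')
```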
@@ -59,8 +66,8 @@ class RelativeAttentionBias extends nn.Module {
}

forward(queryLength: number, keyLength: number, offset = 0) {
const contextPosition = mx.arange(offset, queryLength).index(mx.Slice(), mx.newaxis);
const memoryPosition = mx.arange(keyLength).index(mx.newaxis, mx.Slice());
const contextPosition = mx.arange(offset, queryLength, 1, mx.int16).index(mx.Slice(), mx.newaxis);
const memoryPosition = mx.arange(keyLength, mx.int16).index(mx.newaxis, mx.Slice());

const relativePosition = mx.subtract(memoryPosition, contextPosition);
const relativePositionBucket = this.relativePositionBucket(
@@ -93,7 +100,8 @@ class RelativeAttentionBias extends nn.Module {
mx.multiply(mx.log(mx.divide(relativePosition.astype(mx.float32),
maxExact)),
scale).astype(mx.int16));
relativePositionIfLarge = mx.minimum(relativePositionIfLarge, numBuckets - 1);
relativePositionIfLarge = mx.minimum(relativePositionIfLarge,
mx.array(numBuckets - 1, mx.int16));
relativeBuckets = mx.add(relativeBuckets,
mx.where(isSmall, relativePosition, relativePositionIfLarge));
return relativeBuckets;
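For orientation, relativePositionBucket maps each signed offset onto a fixed set of bias buckets: exact buckets for short distances, log-spaced buckets out to relativeAttentionMaxDistance, and, in the bidirectional encoder, separate halves for past and future. A worked example with the defaults (32 buckets, max distance 128), as I read the T5 scheme:

```ts
// Bidirectional case: numBuckets 32 -> 16 per direction, maxExact = 16 / 2 = 8.
// |offset|  0..7   -> buckets 0..7                                  (exact)
// |offset|  20     -> bucket 8 + floor(ln(20/8) / ln(128/8) * 8) = 10
// |offset| >= 128  -> bucket 15                                     (clamped)
// Offsets in the opposite direction occupy the remaining buckets 16..31.
```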
@@ -202,15 +210,17 @@ class Attention extends nn.Module {
const [ , S, ] = keys.shape;

queries = queries.reshape(B, L, numHeads, -1).transpose(0, 2, 1, 3);
keys = keys.reshape(B, S, numHeads, -1).transpose(0, 2, 3, 1);
keys = keys.reshape(B, S, numHeads, -1).transpose(0, 2, 1, 3);
values = values.reshape(B, S, numHeads, -1).transpose(0, 2, 1, 3);

if (cache)
[ keys, values ] = cache.updateAndFetch(keys, values);

const scale = Math.sqrt(1 / queries.shape.at(-1));
let output = mx.fast.scaledDotProductAttention(queries.astype(mx.float32), keys, values, scale, mask).astype(values.dtype);
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1);
let scores = mx.matmul(queries, keys.transpose(0, 1, 3, 2));
if (mask)
scores = mx.add(scores, mask);
scores = mx.softmax(scores.astype(mx.float32), -1).astype(scores.dtype);
const output = mx.matmul(scores, values).transpose(0, 2, 1, 3).reshape(B, L, -1);
return this.o.forward(output);
}
}
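Two details in this Attention rewrite are easy to miss. The keys now stay in [B, heads, S, headDim] layout when handed to cache.updateAndFetch, which appears to append along the sequence axis; the old (0, 2, 3, 1) transpose left the sequence on the last axis, so the score-side transpose moves into the mx.matmul call instead. And the rewritten math is deliberately unscaled: T5 computes softmax(QK^T + bias)V with no 1/sqrt(d) factor (the trained relative position bias absorbs the scale), which is why the scale argument from the old scaledDotProductAttention call disappears rather than being ported over.

```ts
// T5 attention as rewritten above (shapes per head; no 1/sqrt(d) scaling):
//   Q [B,H,L,D], K [B,H,S,D], V [B,H,S,D]
//   scores = Q @ K.transpose(0,1,3,2) + mask   // mask carries the position bias
//   out    = softmax(scores, -1) @ V           // then merge heads and project
```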
@@ -280,15 +290,15 @@ class EncoderBlock extends nn.Module {
}

class Encoder extends nn.Module {
layers: EncoderBlock[] = [];
block: EncoderBlock[] = [];
ln: nn.RMSNorm;
relativeAttentionBias: RelativeAttentionBias;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
super();
for (let i = 0; i < args.numLayers; ++i)
this.layers.push(new EncoderBlock(args));
this.block.push(new EncoderBlock(args));
this.ln = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.relativeAttentionBias = new RelativeAttentionBias(args, true);
this.dropout = new nn.Dropout(args.dropoutRate);
@@ -297,7 +307,7 @@
forward(x: mx.array) {
const L = x.shape[1];
const positionBias = this.relativeAttentionBias.forward(L, L);
for (const layer of this.layers)
for (const layer of this.block)
x = layer.forward(x, positionBias);
x = this.ln.forward(x);
x = this.dropout.forward(x);
@@ -332,22 +342,22 @@ class DecoderBlock extends nn.Module {
y = this.crossAttention.forward(y, memory, memory, memoryMask);
x = mx.add(x, y);
y = this.ln3.forward(x);
y = this.dense.forward(x);
y = this.dense.forward(y);
x = mx.add(x, y);
return x;
}
}

class Decoder extends nn.Module {
layers: DecoderBlock[] = [];
block: DecoderBlock[] = [];
ln: nn.RMSNorm;
relativeAttentionBias: RelativeAttentionBias;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
super();
for (let i = 0; i < args.numDecoderLayers; ++i)
this.layers.push(new DecoderBlock(args));
this.block.push(new DecoderBlock(args));
this.ln = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.relativeAttentionBias = new RelativeAttentionBias(args, false);
this.dropout = new nn.Dropout(args.dropoutRate);
@@ -361,8 +371,8 @@
mask = mx.add(mask, positionBias);
else
mask = positionBias;
for (let i in this.layers)
x = this.layers[i].forward(x, memory, memoryMask, mask, cache ? cache[i] : undefined);
for (let i in this.block)
x = this.block[i].forward(x, memory, mask, memoryMask, cache ? cache[i] : undefined);
x = this.ln.forward(x);
x = this.dropout.forward(x);
return x;
@@ -374,29 +384,44 @@ export class Model extends BaseModel {
shared: nn.Embedding;
encoder: Encoder;
decoder: Decoder;
lmHead: nn.Linear;
lmHead?: nn.Linear;

constructor(json: any) {
const args = modelArgs(json);
super();

const args = modelArgs(json);
this.args = args;
this.shared = new nn.Embedding(args.vocabSize, args.dModel);
this.encoder = new Encoder(args);
this.decoder = new Decoder(args);
this.lmHead = new nn.Linear(args.dModel, args.vocabSize, false);
if (!args.tieWordEmbeddings)
this.lmHead = new nn.Linear(args.dModel, args.vocabSize, false);

this.hasEncoder = true;
this.decoderStartToken = args.decoderStartTokenId;
}

override computeTextEmbeddings(inputs: mx.array): mx.array {
return this.shared.forward(inputs);
}

override forwardEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
return this.decode(mx.array([ 0 ], mx.int16), this.encoder.forward(embeddings), cache);
override decodeEmbeddings(embeddings: mx.array, memory: mx.array, cache?: BaseKVCache[]): mx.array {
if (!memory)
throw new Error('This model is not decoder-only.');
const mask = createAttentionMask(embeddings, cache);
let y = this.decoder.forward(embeddings, memory, mask, undefined, cache);
if (this.lmHead)
return this.lmHead.forward(y);
y = mx.multiply(y, this.args.dModel ** -0.5);
y = mx.matmul(y, this.shared.weight.T);
return y;
}

override encodeTextEmbeddings(embeddings: mx.array, cache?: BaseKVCache[]): mx.array {
return this.encoder.forward(embeddings);
}

override get layers() {
return this.decoder.layers;
return this.decoder.block;
}

override get headDim() {
@@ -406,10 +431,4 @@
override get nKVHeads() {
return this.args.numHeads;
}

decode(inputs: mx.array, memory: mx.array, cache?: BaseKVCache[]) {
const mask = createAttentionMask(inputs, cache);
const y = this.decoder.forward(inputs, memory, mask, undefined, cache);
return this.lmHead.forward(y);
}
}
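One last T5 detail: in the tied-embedding path, decodeEmbeddings rescales the decoder output by dModel^-0.5 before projecting through the shared embedding matrix, matching the reference T5 behavior where tied logits are computed as (h * d_model^-0.5) @ E^T so their magnitude matches a separately trained head. That projection in isolation (a sketch using the same ops as the diff):

```ts
// Tied T5 output head: rescale hidden states, then project through E^T.
function t5TiedLogits(y: mx.array, shared: nn.Embedding, dModel: number): mx.array {
  const scaled = mx.multiply(y, dModel ** -0.5);
  return mx.matmul(scaled, shared.weight.T);
}
```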
