Be compatible with official T5 models
zcbenz committed Oct 4, 2024
1 parent 6268207 commit f87cda2
Showing 1 changed file with 50 additions and 73 deletions.
123 changes: 50 additions & 73 deletions src/models/t5.ts
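The regrouping below renames and nests the submodules so the parameter tree lines up with the official T5 checkpoints, letting their weights load without remapping. A sketch of the correspondence this appears to target (key names are the ones used by the standard Hugging Face T5 checkpoints; the weight loader is not part of this diff, so the exact snake_case-to-camelCase conversion is an assumption):

// Assumed checkpoint keys (official T5) and the matching module path in this file.
const assumedKeyLayout = [
  'encoder.block.0.layer.0.SelfAttention.q.weight',                        // EncoderBlock.layer[0].SelfAttention.q
  'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',  // only block 0 carries the bias table
  'encoder.block.0.layer.0.layer_norm.weight',                             // LayerSelfAttention.layerNorm
  'encoder.block.0.layer.1.DenseReluDense.wi.weight',                      // LayerFF (wi_0/wi_1 in gated variants)
  'encoder.block.0.layer.1.layer_norm.weight',                             // LayerFF.layerNorm
  'encoder.final_layer_norm.weight',                                       // Encoder.finalLayerNorm
  'decoder.block.0.layer.1.EncDecAttention.q.weight',                      // LayerCrossAttention.EncDecAttention.q
  'decoder.block.0.layer.2.DenseReluDense.wo.weight',                      // DecoderBlock.layer[2]
  'decoder.final_layer_norm.weight',                                       // Decoder.finalLayerNorm
];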
@@ -58,14 +58,11 @@ export function modelArgs(json: any): ModelArgs {
}

class RelativeAttentionBias extends nn.Module {
embeddings: nn.Embedding;

constructor(public args: ModelArgs, public bidirectional: boolean) {
super();
this.embeddings = new nn.Embedding(args.relativeAttentionNumBuckets, args.numHeads);
}

forward(queryLength: number, keyLength: number, offset = 0) {
forward(embeddings: nn.Embedding, queryLength: number, keyLength: number, offset = 0) {
const contextPosition = mx.arange(offset, queryLength, 1, mx.int16).index(mx.Slice(), mx.newaxis);
const memoryPosition = mx.arange(keyLength, mx.int16).index(mx.newaxis, mx.Slice());

@@ -75,7 +72,7 @@ class RelativeAttentionBias extends nn.Module {
this.bidirectional,
this.args.relativeAttentionNumBuckets,
this.args.relativeAttentionMaxDistance);
const values = this.embeddings.forward(relativePositionBucket);
const values = embeddings.forward(relativePositionBucket);
return values.transpose(2, 0, 1);
}
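For context, the call above (its opening line falls outside this hunk) buckets the relative positions with T5's standard scheme: small offsets keep their own bucket and larger ones are binned logarithmically up to relativeAttentionMaxDistance. A scalar sketch of that scheme, assuming the repo's helper follows the reference T5 formulation (the real helper operates on mx.array inputs and its name is not visible here):

// Scalar sketch of T5-style relative position bucketing; illustration only,
// not this repo's actual helper.
function relativePositionBucketScalar(relativePosition: number, bidirectional: boolean,
                                      numBuckets: number, maxDistance: number): number {
  let bucket = 0;
  let n = relativePosition;
  if (bidirectional) {
    numBuckets = Math.floor(numBuckets / 2);
    if (n > 0)           // keys to the right get the upper half of the buckets
      bucket += numBuckets;
    n = Math.abs(n);
  } else {
    n = Math.max(-n, 0); // causal case: only distances to past positions matter
  }
  const maxExact = Math.floor(numBuckets / 2); // small offsets map one-to-one
  if (n < maxExact)
    return bucket + n;
  // Larger offsets are spaced logarithmically up to maxDistance.
  const large = maxExact + Math.floor(
      (Math.log(n / maxExact) / Math.log(maxDistance / maxExact)) * (numBuckets - maxExact));
  return bucket + Math.min(large, numBuckets - 1);
}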

@@ -165,21 +162,18 @@ class DenseGatedActDense extends nn.Module {
class LayerFF extends nn.Module {
DenseReluDense: DenseActDense | DenseGatedActDense;
layerNorm: nn.RMSNorm;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
super();
this.DenseReluDense = args.isGatedAct ? new DenseGatedActDense(args)
: new DenseActDense(args);
this.layerNorm = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.dropout = new nn.Dropout(args.dropoutRate);
}

forward(x: mx.array) {
let f = this.layerNorm.forward(x);
f = this.DenseReluDense.forward(f);
x = mx.add(x, this.dropout.forward(f));
return x;
let y = this.layerNorm.forward(x);
y = this.DenseReluDense.forward(y);
return y;
}
}
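With the dropout and the residual add removed here, LayerFF (and the attention layers below) now return only the sublayer's output; the enclosing block applies the residual itself, matching the official module layout. A minimal sketch of that pre-norm residual pattern, using the file's existing mx import (the blocks below inline it rather than calling a helper like this):

// Illustration only: each block computes x = x + sublayer(x), where the
// sublayer applies its own RMSNorm, so the residual stream stays un-normalized.
function applyResidual(x: mx.array, sublayer: (h: mx.array) => mx.array): mx.array {
  return mx.add(x, sublayer(x));
}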

@@ -188,16 +182,17 @@ class Attention extends nn.Module {
k: nn.Linear;
v: nn.Linear;
o: nn.Linear;
relativeAttentionBias?: nn.Embedding;

constructor(public args: ModelArgs,
public isDecoder: boolean,
public hasRelativeAttentionBias = false) {
constructor(public args: ModelArgs, public hasRelativeAttentionBias = false) {
super();
const innderDim = args.numHeads * args.dKv;
this.q = new nn.Linear(args.dModel, innderDim, false);
this.k = new nn.Linear(args.dModel, innderDim, false);
this.v = new nn.Linear(args.dModel, innderDim, false);
this.o = new nn.Linear(innderDim, args.dModel, false);
if (hasRelativeAttentionBias)
this.relativeAttentionBias = new nn.Embedding(args.relativeAttentionNumBuckets, args.numHeads);
}

forward(queries: mx.array, keys: mx.array, values: mx.array, mask?: mx.array, cache?: BaseKVCache) {
@@ -218,7 +213,7 @@

let scores = mx.matmul(queries, keys.transpose(0, 1, 3, 2));
if (mask)
scores = mx.add(scores, mask);
scores = mx.add(scores, mask.astype(scores.dtype));
scores = mx.softmax(scores.astype(mx.float32), -1).astype(scores.dtype);
const output = mx.matmul(scores, values).transpose(0, 2, 1, 3).reshape(B, L, -1);
return this.o.forward(output);
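Two details in this hunk: T5 attention deliberately omits the usual 1/sqrt(dKv) scaling of the scores, and the relative position bias only reaches Attention through the additive mask, so the new astype presumably keeps a float32 bias from upcasting half-precision scores. A small sketch of how callers merge the two inputs (Encoder and Decoder below follow exactly this pattern), using the file's existing mx import:

// Hypothetical helper: fold the relative position bias and an optional
// causal mask into the single additive `mask` argument of Attention.forward.
function combineBiasAndMask(positionBias: mx.array, causalMask?: mx.array): mx.array {
  return causalMask ? mx.add(causalMask, positionBias) : positionBias;
}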
@@ -228,152 +223,134 @@
class LayerSelfAttention extends nn.Module {
SelfAttention: Attention;
layerNorm: nn.RMSNorm;
dropout: nn.Dropout;

constructor(args: ModelArgs, isDecoder: boolean, hasRelativeAttentionBias = false) {
constructor(args: ModelArgs, hasRelativeAttentionBias = false) {
super();
this.SelfAttention = new Attention(args, isDecoder, hasRelativeAttentionBias);
this.SelfAttention = new Attention(args, hasRelativeAttentionBias);
this.layerNorm = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.dropout = new nn.Dropout(args.dropoutRate);
}

forward(x: mx.array, keys?: mx.array, values?: mx.array, mask?: mx.array, cache?: BaseKVCache) {
forward(x: mx.array, mask?: mx.array, cache?: BaseKVCache) {
let y = this.layerNorm.forward(x);
y = this.SelfAttention.forward(y, keys ?? y, values ?? y, mask, cache);
return mx.add(x, this.dropout.forward(y));
y = this.SelfAttention.forward(y, y, y, mask, cache);
return y;
}
}

class LayerCrossAttention extends nn.Module {
EncDecAttention: Attention;
layerNorm: nn.RMSNorm;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
constructor(args: ModelArgs, hasRelativeAttentionBias = false) {
super();
this.EncDecAttention = new Attention(args, true, false);
this.EncDecAttention = new Attention(args, hasRelativeAttentionBias);
this.layerNorm = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.dropout = new nn.Dropout(args.dropoutRate);
}

forward(x: mx.array, keys: mx.array, values: mx.array, mask?: mx.array, cache?: BaseKVCache) {
forward(x: mx.array, memory: mx.array, mask?: mx.array, cache?: BaseKVCache) {
let y = this.layerNorm.forward(x);
y = this.EncDecAttention.forward(x, keys, values, mask, cache);
return mx.add(x, this.dropout.forward(y));
y = this.EncDecAttention.forward(y, memory, memory, mask, cache);
return y;
}
}

class EncoderBlock extends nn.Module {
attention: Attention;
ln1: nn.RMSNorm;
ln2: nn.RMSNorm;
dense: DenseActDense | DenseGatedActDense;
layer: [ LayerSelfAttention, LayerFF ];

constructor(args: ModelArgs) {
constructor(args: ModelArgs, hasRelativeAttentionBias = false) {
super();
this.attention = new Attention(args, false);
this.ln1 = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.ln2 = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.dense = args.isGatedAct ? new DenseGatedActDense(args)
: new DenseActDense(args);
this.layer = [
new LayerSelfAttention(args, hasRelativeAttentionBias),
new LayerFF(args),
];
}

forward(x: mx.array, mask?: mx.array) {
let y = this.ln1.forward(x);
y = this.attention.forward(y, y, y, mask);
let y = this.layer[0].forward(x, mask);
x = mx.add(x, y);
y = this.ln2.forward(x);
y = this.dense.forward(y);
y = this.layer[1].forward(x);
x = mx.add(x, y);
return x;
}
}

class Encoder extends nn.Module {
block: EncoderBlock[] = [];
ln: nn.RMSNorm;
finalLayerNorm: nn.RMSNorm;
relativeAttentionBias: RelativeAttentionBias;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
super();
for (let i = 0; i < args.numLayers; ++i)
this.block.push(new EncoderBlock(args));
this.ln = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.block.push(new EncoderBlock(args, i == 0));
this.finalLayerNorm = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.relativeAttentionBias = new RelativeAttentionBias(args, true);
this.dropout = new nn.Dropout(args.dropoutRate);
}

forward(x: mx.array) {
const L = x.shape[1];
const positionBias = this.relativeAttentionBias.forward(L, L);
const embeddings = this.block[0].layer[0].SelfAttention.relativeAttentionBias;
const positionBias = this.relativeAttentionBias.forward(embeddings, L, L);
for (const layer of this.block)
x = layer.forward(x, positionBias);
x = this.ln.forward(x);
x = this.finalLayerNorm.forward(x);
x = this.dropout.forward(x);
return x;
}
}
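Only block 0 owns the bias table now, so the encoder computes a single bias from that embedding and passes the same tensor to every block as the additive mask. Shape-wise, RelativeAttentionBias.forward returns (numHeads, L, L), which broadcasts over the batch axis of the (B, numHeads, L, L) scores; a quick check with illustrative shapes, using the file's existing mx import:

// Broadcast check for the shared position bias (shapes are examples only).
const scores = mx.zeros([2, 8, 5, 5]);   // (B, numHeads, L, L)
const bias = mx.zeros([8, 5, 5]);        // (numHeads, L, L)
console.log(mx.add(scores, bias).shape); // [2, 8, 5, 5]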

class DecoderBlock extends nn.Module {
selfAttention: Attention;
crossAttention: Attention;
ln1: nn.RMSNorm;
ln2: nn.RMSNorm;
ln3: nn.RMSNorm;
dense: DenseActDense | DenseGatedActDense;
layer: [ LayerSelfAttention, LayerCrossAttention, LayerFF ];

constructor(args: ModelArgs) {
constructor(args: ModelArgs, hasRelativeAttentionBias = false) {
super();
this.selfAttention = new Attention(args, true);
this.crossAttention = new Attention(args, true);
this.ln1 = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.ln2 = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.ln3 = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.dense = args.isGatedAct ? new DenseGatedActDense(args)
: new DenseActDense(args);
this.layer = [
new LayerSelfAttention(args, hasRelativeAttentionBias),
new LayerCrossAttention(args, hasRelativeAttentionBias),
new LayerFF(args),
];
}

forward(x: mx.array, memory: mx.array, mask?: mx.array, memoryMask?: mx.array, cache?: BaseKVCache) {
let y = this.ln1.forward(x);
y = this.selfAttention.forward(y, y, y, mask, cache);
let y = this.layer[0].forward(x, mask, cache);
x = mx.add(x, y);
y = this.ln2.forward(x);
y = this.crossAttention.forward(y, memory, memory, memoryMask);
y = this.layer[1].forward(x, memory, memoryMask);
x = mx.add(x, y);
y = this.ln3.forward(x);
y = this.dense.forward(y);
y = this.layer[2].forward(x);
x = mx.add(x, y);
return x;
}
}

class Decoder extends nn.Module {
block: DecoderBlock[] = [];
ln: nn.RMSNorm;
finalLayerNorm: nn.RMSNorm;
relativeAttentionBias: RelativeAttentionBias;
dropout: nn.Dropout;

constructor(args: ModelArgs) {
super();
for (let i = 0; i < args.numDecoderLayers; ++i)
this.block.push(new DecoderBlock(args));
this.ln = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.block.push(new DecoderBlock(args, i == 0));
this.finalLayerNorm = new nn.RMSNorm(args.dModel, args.layerNormEpsilon);
this.relativeAttentionBias = new RelativeAttentionBias(args, false);
this.dropout = new nn.Dropout(args.dropoutRate);
}

forward(x: mx.array, memory: mx.array, mask?: mx.array, memoryMask?: mx.array, cache?: BaseKVCache[]) {
const offset = cache ? cache[0].offset : 0;
const T = offset + x.shape[1];
const positionBias = this.relativeAttentionBias.forward(T, T, offset);
const embeddings = this.block[0].layer[0].SelfAttention.relativeAttentionBias;
const positionBias = this.relativeAttentionBias.forward(embeddings, T, T, offset);
if (mask)
mask = mx.add(mask, positionBias);
else
mask = positionBias;
for (let i in this.block)
x = this.block[i].forward(x, memory, mask, memoryMask, cache ? cache[i] : undefined);
x = this.ln.forward(x);
x = this.finalLayerNorm.forward(x);
x = this.dropout.forward(x);
return x;
}
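During incremental decoding the driver (outside this diff) passes one BaseKVCache per block; cache[0].offset shifts the relative positions so each new query row picks up the correct bias against all cached keys, and that bias alone serves as the additive mask. A hypothetical single-step sketch reusing the types defined in this file:

// Hypothetical decode step; `decoder`, `memory`, and `caches` would live on
// the surrounding Model class, which this diff does not show.
function decodeStep(decoder: Decoder, y: mx.array, memory: mx.array,
                    caches: BaseKVCache[]): mx.array {
  // No explicit causal mask is needed for a single-token step: the new query
  // may attend to every cached position, and forward() falls back to using
  // the position bias as the mask when none is given.
  return decoder.forward(y, memory, undefined, undefined, caches);
}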
