
Commit

Add support for gemma
zcbenz committed Oct 21, 2024
1 parent df015a2 commit 13ba668
Showing 3 changed files with 45 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
@@ -31,6 +31,9 @@ jobs:
huggingface download --silent Xenova/t5-small --hf # fix tokenizer
yarn tsx src/generate.ts --temperature=0 t5-small 'translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast.'
+ huggingface download --silent mlx-community/codegemma-1.1-2b-8bit
+ yarn tsx src/generate.ts --max-tokens=128 codegemma-1.1-2b-8bit
publish:
if: startsWith(github.ref, 'refs/tags/')
needs: [build]
11 changes: 11 additions & 0 deletions src/models/gemma.ts
@@ -0,0 +1,11 @@
import * as llama from './llama.js';

export class Model extends llama.Model {
constructor(json: any) {
if (json.force_use_exact_gelu)
json.hidden_act = 'gelu';
else if (!json.hidden_act || json.hidden_act == 'gelu_pytorch_tanh')
json.hidden_act = 'geluApprox';
super(json);
}
}
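Note: the whole Gemma implementation above is a thin subclass of the Llama model; the only work it does is translate the activation name found in the checkpoint's config into one of the activations ModelArgs understands. Below is a minimal standalone sketch of that remapping rule (the helper name and sample configs are hypothetical, not part of the commit):

// Hypothetical helper mirroring the constructor's hidden_act remapping.
function remapGemmaActivation(json: {force_use_exact_gelu?: boolean, hidden_act?: string}): string {
  if (json.force_use_exact_gelu)
    return 'gelu';            // exact GELU when the config explicitly asks for it
  if (!json.hidden_act || json.hidden_act == 'gelu_pytorch_tanh')
    return 'geluApprox';      // Gemma's usual tanh-approximated GELU
  return json.hidden_act;     // anything else passes through unchanged
}

remapGemmaActivation({hidden_act: 'gelu_pytorch_tanh'});  // => 'geluApprox'
remapGemmaActivation({force_use_exact_gelu: true});       // => 'gelu'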
40 changes: 31 additions & 9 deletions src/models/llama.ts
@@ -12,10 +12,11 @@ export interface RopeScaling {
}

export interface ModelArgs {
- modelType: 'llama';
+ modelType: 'gemma' | 'llama' | 'qwen2';
attentionBias: boolean;
attentionOutProjectionBias: boolean;
headDim?: number;
+ hiddenAct: 'gelu' | 'geluApprox' | 'silu';
hiddenSize: number;
intermediateSize: number;
maxPositionEmbeddings: number;
@@ -39,12 +40,12 @@ export function modelArgs(json: any): ModelArgs {
ropeTraditional: false,
tieWordEmbeddings: true,
}, baseModelArgs(json));
- if (args.attentionOutProjectionBias == undefined) {
+ if (args.attentionOutProjectionBias === undefined)
args.attentionOutProjectionBias = args.attentionBias;
- }
- if (!args.numKeyValueHeads) {
+ if (args.hiddenAct == undefined)
+ args.hiddenAct = 'silu';
+ if (args.numKeyValueHeads === undefined)
args.numKeyValueHeads = args.numAttentionHeads;
- }
if (args.ropeScaling) {
if (!args.ropeScaling.factor)
throw new Error('rope_scaling must contain "factor"')
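Note on the hunk above: the defaulting now also fills in hiddenAct, so configs that omit hidden_act keep the previous SiLU behaviour, and num_key_value_heads still falls back to num_attention_heads. A standalone sketch of just those two fallbacks (hypothetical helper, not the repository's API):

// Hypothetical helper illustrating the fallbacks applied in modelArgs above.
function applyDefaults(args: {hiddenAct?: string, numAttentionHeads: number, numKeyValueHeads?: number}) {
  return {
    hiddenAct: args.hiddenAct ?? 'silu',                                // unchanged behaviour for Llama/Qwen2 configs
    numKeyValueHeads: args.numKeyValueHeads ?? args.numAttentionHeads,  // multi-head attention unless the config says otherwise
  };
}

applyDefaults({numAttentionHeads: 32});
// => {hiddenAct: 'silu', numKeyValueHeads: 32}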
@@ -57,6 +58,16 @@
return args;
}

+ export class GemmaRMSNorm extends nn.RMSNorm {
+ constructor(dims: number, eps: number) {
+ super(dims, eps);
+ }
+
+ override forward(x: mx.array): mx.array {
+ return mx.fast.rmsNorm(x, mx.add(1, this.weight), this.eps);
+ }
+ }
+
class DynamicNTKScalingRoPE extends nn.Module {
#freqs?: mx.array;

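Note on GemmaRMSNorm, added a few lines up: it differs from the stock nn.RMSNorm only in applying the learned scale as (1 + weight), since Gemma checkpoints store the scale as an offset from one (zero-initialized weights give an identity scale). A dependency-free sketch of the computation, using plain numbers instead of mx.array and a hypothetical helper name:

// RMS-normalize x, then scale by (1 + weight), as GemmaRMSNorm.forward does above.
function gemmaRmsNorm(x: number[], weight: number[], eps = 1e-6): number[] {
  const meanSquare = x.reduce((sum, v) => sum + v * v, 0) / x.length;
  const inv = 1 / Math.sqrt(meanSquare + eps);
  return x.map((v, i) => v * inv * (1 + weight[i]));
}

// With all-zero weights the layer reduces to a plain RMS norm.
gemmaRmsNorm([3, 4], [0, 0]);  // => [≈0.849, ≈1.131]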
@@ -191,6 +202,7 @@ class MLP extends nn.Module {
gateProj: nn.Linear;
downProj: nn.Linear;
upProj: nn.Linear;
+ activation: (x: mx.array) => mx.array;

constructor(args: ModelArgs) {
super();
@@ -202,10 +214,11 @@
this.gateProj = new nn.Linear(dim, hiddenDim, mlpBias);
this.downProj = new nn.Linear(hiddenDim, dim, mlpBias);
this.upProj = new nn.Linear(dim, hiddenDim, mlpBias);
+ this.activation = nn[args.hiddenAct as 'silu'];
}

forward(x: mx.array) {
- return this.downProj.forward(mx.multiply(nn.silu(this.gateProj.forward(x)),
+ return this.downProj.forward(mx.multiply(this.activation(this.gateProj.forward(x)),
this.upProj.forward(x)));
}
}
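Note on the MLP hunk above: the activation is no longer hard-coded to SiLU; it is looked up by name on the nn namespace (the `as 'silu'` cast only narrows the index type for TypeScript). A sketch of that dispatch, assuming nn exports gelu, geluApprox and silu as functions, which is exactly what the line above relies on, and reusing the same `nn` import that sits at the top of llama.ts:

// Map the hiddenAct string from ModelArgs to the corresponding nn function.
type HiddenAct = 'gelu' | 'geluApprox' | 'silu';
function pickActivation(name: HiddenAct) {
  return nn[name];  // e.g. 'geluApprox' -> nn.geluApprox for Gemma, 'silu' -> nn.silu for Llama
}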
@@ -220,8 +233,9 @@ class TransformerBlock extends nn.Module {
super();
this.selfAttn = new Attention(args);
this.mlp = new MLP(args);
- this.inputLayernorm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
- this.postAttentionLayernorm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+ const norm = args.modelType == 'gemma' ? GemmaRMSNorm : nn.RMSNorm;
+ this.inputLayernorm = new norm(args.hiddenSize, args.rmsNormEps);
+ this.postAttentionLayernorm = new norm(args.hiddenSize, args.rmsNormEps);
}

forward(x: mx.array, mask: mx.array, cache?: BaseKVCache) {
@@ -236,18 +250,26 @@ class LlamaModel extends nn.Module {
embedTokens: nn.Embedding;
layers: TransformerBlock[];
norm: nn.RMSNorm;
+ gemmaNormalizer?: number;

constructor(args: ModelArgs) {
super();
this.embedTokens = new nn.Embedding(args.vocabSize, args.hiddenSize);
this.layers = [];
for (let i = 0; i < args.numHiddenLayers; ++i)
this.layers.push(new TransformerBlock(args));
- this.norm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+ if (args.modelType == 'gemma') {
+ this.norm = new GemmaRMSNorm(args.hiddenSize, args.rmsNormEps);
+ this.gemmaNormalizer = args.hiddenSize ** 0.5;
+ } else {
+ this.norm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+ }
}

forward(embeddings: mx.array, cache?: BaseKVCache[]) {
let h = embeddings;
+ if (this.gemmaNormalizer !== undefined)
+ h = mx.multiply(h, this.gemmaNormalizer);
const mask = createAttentionMask(h, cache);
for (let i in this.layers)
h = this.layers[i].forward(h, mask, cache ? cache[i] : undefined);
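Note on the last hunk: LlamaModel picks up one more Gemma-only detail; token embeddings are multiplied by sqrt(hidden_size) before the first transformer block, with the factor precomputed once in the constructor. A small worked example (the 2048 below is illustrative, Gemma 2B's published hidden size, not something read from this diff):

// Gemma scales embeddings by sqrt(hiddenSize); other model types skip this.
const hiddenSize = 2048;                    // illustrative value (Gemma 2B)
const gemmaNormalizer = hiddenSize ** 0.5;  // ≈ 45.25
// In forward(): h = mx.multiply(h, gemmaNormalizer) applies this factor
// element-wise before the attention mask is built and the layers run.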
